#!/usr/bin/env python3
"""
EDEN LEARNING & RECALL MODULE
Fixes: Learning & Adaptation (33% → 100%)

Enables Eden to learn new facts and recall them later.
Extracts and stores factual information from conversations.
"""

import re
import json
from datetime import datetime
from typing import List, Dict, Optional
import sqlite3
from pathlib import Path

class FactMemory:
    """
    Manages factual knowledge learned from conversations.

    Facts live in a single SQLite table named ``facts`` inside the file at
    ``db_path``. Every method opens its own short-lived connection and closes
    it in a ``finally`` clause, so a failing statement never leaks a handle.
    """

    # Pattern 1: "X is Y" / "X are Y". The verb is captured so the stored fact
    # keeps it, and the sentence terminator is only looked at (not consumed)
    # so the *next* sentence can still match its leading separator.
    _DEFINITION_RE = re.compile(
        r"(?:^|[.!?]\s+)([A-Z][^.!?]*?)\s+(is|are)\s+([^.!?]+)(?=[.!?]|$)"
    )
    # Pattern 2: "The X of Y is Z"
    _PROPERTY_RE = re.compile(
        r"(?:The |the )([^.!?]*?)\s+of\s+([^.!?]*?)\s+is\s+([^.!?]+?)(?:\.|!|\?|$)"
    )
    # Pattern 3: "Remember: X" or "Remember this: X"
    _REMEMBER_RE = re.compile(r"[Rr]emember(?:\s+this)?:\s*([^.!?]+)")
    # Pattern 4: "My favorite X is Y"
    _FAVORITE_RE = re.compile(
        r"[Mm]y favorite\s+([^.!?]*?)\s+is\s+([^.!?]+?)(?:\.|!|\?|$)"
    )

    def __init__(self, db_path: str = "eden_facts.db"):
        """Remember the database location and make sure the schema exists."""
        self.db_path = db_path
        self._init_database()

    def _init_database(self):
        """Initialize the facts database (table and subject index)."""
        conn = sqlite3.connect(self.db_path)
        try:
            conn.execute("""
                CREATE TABLE IF NOT EXISTS facts (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    category TEXT,
                    subject TEXT,
                    fact TEXT,
                    confidence REAL DEFAULT 1.0,
                    learned_from TEXT,
                    learned_at TEXT,
                    user_id TEXT
                )
            """)
            conn.execute("""
                CREATE INDEX IF NOT EXISTS idx_subject ON facts(subject)
            """)
            conn.commit()
        finally:
            conn.close()

    @staticmethod
    def _escape_like(value: str) -> str:
        """Escape LIKE metacharacters so *value* is matched literally."""
        return (
            value.replace("\\", "\\\\")
            .replace("%", "\\%")
            .replace("_", "\\_")
        )

    def extract_facts(self, text: str, user_id: str) -> List[Dict]:
        """
        Extract potential facts from text.

        Looks for patterns like:
        - "X is Y" / "X are Y"
        - "The X of Y is Z"
        - "Remember: X"
        - "My favorite X is Y"

        A sentence may match more than one pattern; each match yields its own
        entry.

        Args:
            text: Free-form text to scan. Empty or ``None`` yields no facts.
            user_id: Identifier copied into every extracted fact.

        Returns:
            List of dicts with keys ``category``, ``subject``, ``fact`` and
            ``user_id``, grouped by pattern in the order listed above.
        """
        if not text:
            return []

        facts = []

        for match in self._DEFINITION_RE.finditer(text):
            subject = match.group(1).strip()
            verb = match.group(2)
            value = match.group(3).strip()
            facts.append({
                'category': 'definition',
                'subject': subject.lower(),
                # Keep the verb that was actually written ("is" or "are").
                'fact': f"{subject} {verb} {value}",
                'user_id': user_id
            })

        for match in self._PROPERTY_RE.finditer(text):
            property_type = match.group(1).strip()
            subject = match.group(2).strip()
            value = match.group(3).strip()
            facts.append({
                'category': 'property',
                'subject': subject.lower(),
                'fact': f"The {property_type} of {subject} is {value}",
                'user_id': user_id
            })

        for match in self._REMEMBER_RE.finditer(text):
            facts.append({
                'category': 'instruction',
                'subject': 'general',
                'fact': match.group(1).strip(),
                'user_id': user_id
            })

        for match in self._FAVORITE_RE.finditer(text):
            category = match.group(1).strip()
            value = match.group(2).strip()
            facts.append({
                'category': 'preference',
                'subject': f"user_favorite_{category.lower()}",
                'fact': f"User's favorite {category} is {value}",
                'user_id': user_id
            })

        return facts

    def store_fact(self, category: str, subject: str, fact: str, 
                   user_id: str, confidence: float = 1.0, source: str = "conversation"):
        """
        Store a learned fact.

        Args:
            category: Kind of fact (e.g. 'definition', 'preference').
            subject: What the fact is about.
            fact: The fact text itself.
            user_id: Owner of the fact.
            confidence: Stored in the ``confidence`` column.
            source: Stored in the ``learned_from`` column.

        The ``learned_at`` column receives ``datetime.now().isoformat()``.
        """
        conn = sqlite3.connect(self.db_path)
        try:
            conn.execute("""
                INSERT INTO facts (category, subject, fact, confidence, learned_from, learned_at, user_id)
                VALUES (?, ?, ?, ?, ?, ?, ?)
            """, (category, subject, fact, confidence, source, datetime.now().isoformat(), user_id))
            conn.commit()
        finally:
            conn.close()

    def store_facts_from_text(self, text: str, user_id: str) -> int:
        """
        Extract and store facts from text.
        Returns number of facts stored.
        """
        facts = self.extract_facts(text, user_id)

        for fact_data in facts:
            self.store_fact(
                category=fact_data['category'],
                subject=fact_data['subject'],
                fact=fact_data['fact'],
                user_id=fact_data['user_id']
            )

        return len(facts)

    def recall_about(self, subject: str, user_id: Optional[str] = None) -> List[str]:
        """
        Recall facts about a subject.

        Args:
            subject: What to recall about. Matched as a literal,
                case-insensitive substring of the stored subject or fact text.
            user_id: Optional user filter. A falsy value (``None`` or ``""``)
                disables the filter.

        Returns:
            List of matching fact texts, highest confidence first, then most
            recently learned first.
        """
        like = f"%{self._escape_like(subject.lower())}%"

        # The OR group is parenthesised on purpose: AND binds tighter than OR,
        # so without the parentheses the user filter would apply only to the
        # fact-text match and subject matches of other users would leak in.
        query = (
            "SELECT fact, confidence, learned_at FROM facts "
            "WHERE (LOWER(subject) LIKE ? ESCAPE '\\' "
            "OR LOWER(fact) LIKE ? ESCAPE '\\')"
        )
        params = [like, like]

        if user_id:
            query += " AND user_id = ?"
            params.append(user_id)

        query += " ORDER BY confidence DESC, learned_at DESC"

        conn = sqlite3.connect(self.db_path)
        try:
            results = conn.execute(query, params).fetchall()
        finally:
            conn.close()

        return [row[0] for row in results]

    def get_all_facts(self, user_id: Optional[str] = None, limit: int = 100) -> List[Dict]:
        """
        Get all stored facts, most recently learned first.

        Args:
            user_id: Optional user filter. A falsy value disables the filter.
            limit: Maximum number of rows to return.

        Returns:
            List of dicts with keys ``category``, ``subject``, ``fact`` and
            ``learned_at``.
        """
        query = "SELECT category, subject, fact, learned_at FROM facts"
        params = []

        if user_id:
            query += " WHERE user_id = ?"
            params.append(user_id)

        query += " ORDER BY learned_at DESC LIMIT ?"
        params.append(limit)

        conn = sqlite3.connect(self.db_path)
        try:
            results = conn.execute(query, params).fetchall()
        finally:
            conn.close()

        return [
            {
                'category': row[0],
                'subject': row[1],
                'fact': row[2],
                'learned_at': row[3]
            }
            for row in results
        ]

    def get_stats(self) -> Dict:
        """
        Get statistics about stored facts.

        Returns:
            Dict with ``total_facts`` (row count) and ``by_category``
            (mapping of category to row count).
        """
        conn = sqlite3.connect(self.db_path)
        try:
            total = conn.execute("SELECT COUNT(*) FROM facts").fetchone()[0]
            by_category = dict(
                conn.execute(
                    "SELECT category, COUNT(*) FROM facts GROUP BY category"
                ).fetchall()
            )
        finally:
            conn.close()

        return {
            'total_facts': total,
            'by_category': by_category
        }


# Global fact memory instance shared by the helper functions in this module.
# NOTE: it is constructed at import time with the default db_path
# ("eden_facts.db"), and the constructor runs _init_database(), so merely
# importing this module opens that SQLite file and ensures the schema exists.
fact_memory = FactMemory()


def integrate_learning_with_chat(user_message: str, assistant_response: str, 
                                  user_id: str) -> str:
    """
    Learn from one conversation turn.
    Intended to run AFTER the response has been generated.

    Only the user's message is mined for facts; the assistant's response is
    accepted but not inspected.

    Args:
        user_message: What the user said
        assistant_response: Eden's response (currently unused)
        user_id: User identifier

    Returns:
        A short status line such as "Learned 2 new fact(s)", or an empty
        string when nothing was stored
    """
    learned_count = fact_memory.store_facts_from_text(user_message, user_id)
    return f"Learned {learned_count} new fact(s)" if learned_count > 0 else ""


def enhance_prompt_with_recall(user_message: str, user_id: str, 
                               base_prompt: str) -> str:
    """
    Enhance prompt with recalled relevant facts.
    Call this BEFORE generating the response.

    Keywords are capitalized words and lowercase words of four or more
    letters. At most 5 distinct keywords are looked up, at most 2 facts are
    taken per keyword, and at most 5 distinct facts are included overall.

    Args:
        user_message: Current user message
        user_id: User identifier
        base_prompt: Existing system prompt

    Returns:
        ``base_prompt`` unchanged when nothing relevant is recalled; otherwise
        ``base_prompt`` followed by a ``<recalled_facts>`` section and a
        ``User message:`` line containing ``user_message``.
    """
    max_keywords = 5
    facts_per_keyword = 2
    max_facts = 5

    # Extract key subjects from the message
    words = re.findall(r'\b[A-Z][a-z]+\b|\b[a-z]{4,}\b', user_message)

    # Recall is case-insensitive, so "Atlantis" and "atlantis" are the same
    # keyword; collapse repeats (first occurrence wins) so they do not use up
    # the keyword budget.
    keywords = list(dict.fromkeys(word.lower() for word in words))[:max_keywords]

    # Several keywords frequently recall the very same fact; keep each fact
    # once so duplicates do not crowd out other facts.
    relevant_facts = []
    for keyword in keywords:
        for fact in fact_memory.recall_about(keyword, user_id)[:facts_per_keyword]:
            if fact not in relevant_facts:
                relevant_facts.append(fact)
        if len(relevant_facts) >= max_facts:
            break  # Budget reached; further lookups cannot change the result
    relevant_facts = relevant_facts[:max_facts]

    if not relevant_facts:
        return base_prompt

    context_lines = ["\n<recalled_facts>", "You previously learned:"]
    context_lines.extend(f"- {fact}" for fact in relevant_facts)
    context_lines.append("</recalled_facts>\n")
    fact_context = "\n".join(context_lines)

    return f"{base_prompt}\n\n{fact_context}\n\nUser message: {user_message}"


# Integration example
"""
COMPLETE INTEGRATION WITH CHAT ENDPOINT:

In your main.py:

from eden_learning_recall import fact_memory, integrate_learning_with_chat, enhance_prompt_with_recall

@app.post("/chat")
async def chat(request: dict):
    message = request.get('message', '')
    user_id = request.get('user_id', 'default')
    
    # STEP 1: Enhance prompt with recalled facts (BEFORE LLM)
    base_prompt = TOOL_INSTRUCTIONS + message
    enhanced_prompt = enhance_prompt_with_recall(
        user_message=message,
        user_id=user_id,
        base_prompt=base_prompt
    )
    
    # STEP 2: Generate response
    response = await generate_response(enhanced_prompt)
    
    # STEP 3: Learn from conversation (AFTER LLM)
    learned = integrate_learning_with_chat(message, response, user_id)
    if learned:
        print(f"📚 {learned}")
    
    return {"response": response}
"""

if __name__ == "__main__":
    # Manual smoke test: learn from one sample message, recall by two
    # subjects, then dump the stats. Runs against the module-level
    # fact_memory instance.
    print("Testing Learning & Recall Module...\n")

    sample_user = "test_user"
    sample_text = "Remember this: The capital of Atlantis is Aquapolis. My favorite color is purple."

    # Learning
    stored_count = fact_memory.store_facts_from_text(sample_text, sample_user)
    print(f"✅ Learned {stored_count} facts from message")

    # Recall, one section per subject
    recall_sections = (
        ("\nRecalling about 'Atlantis':", "Atlantis"),
        ("\nRecalling about 'color':", "color"),
    )
    for heading, topic in recall_sections:
        print(heading)
        for recalled in fact_memory.recall_about(topic, sample_user):
            print(f"  - {recalled}")

    # Stats
    print("\nFact Memory Stats:")
    print(json.dumps(fact_memory.get_stats(), indent=2))

    print("\n✅ Learning & Recall Module working!")
