#!/usr/bin/env python3
"""
Memory V3 - Pre-trained Sentence Embeddings
Uses pre-trained Sentence-BERT (sentence-transformers) embeddings for semantic retrieval
"""

import torch
import torch.nn.functional as F
import numpy as np
from sentence_transformers import SentenceTransformer
from datetime import datetime, timedelta
import random
from tqdm import tqdm

# =============================================================================
# MEMORY WITH PRE-TRAINED EMBEDDINGS
# =============================================================================

class Episode:
    """A single stored experience.

    Attributes:
        content: the text of the experience.
        timestamp: when it happened (PretrainedMemorySystem passes a datetime).
        category: optional free-form label used for category filtering.
        embedding: optional vector for the content; PretrainedMemorySystem
            stores a CPU torch tensor here.
        id: integer identifier, unique among all Episodes created in this process.
    """

    # Process-wide counter for unique ids. id(self) is NOT used because CPython
    # may recycle the address of an evicted, garbage-collected episode for a
    # newly created one, which would produce duplicate ids.
    _next_id = 0

    def __init__(self, content, timestamp, category=None, embedding=None):
        self.content = content
        self.timestamp = timestamp
        self.category = category
        self.embedding = embedding
        self.id = Episode._next_id
        Episode._next_id += 1

class MemoryStore:
    """Size-bounded, insertion-ordered list of episodes.

    When the number of stored items exceeds ``max_size``, the oldest items
    (front of the list) are dropped first.
    """

    def __init__(self, max_size=10000):
        self.memories = []
        self.max_size = max_size

    def add(self, episode):
        """Append ``episode`` and evict the oldest entries beyond ``max_size``."""
        self.memories.append(episode)
        # Trim the whole overflow in one slice deletion; this also restores
        # the bound correctly if max_size was lowered after items were added.
        overflow = len(self.memories) - self.max_size
        if overflow > 0:
            del self.memories[:overflow]

    def get_all(self):
        """Return the live list of stored items, oldest first."""
        return self.memories

    def get_recent(self, n=10):
        """Return (a new list of) the ``n`` most recently added items, oldest first.

        ``n <= 0`` yields an empty list. Without this guard, ``lst[-0:]``
        would be ``lst[0:]`` and return everything.
        """
        if n <= 0:
            return []
        return self.memories[-n:]

    def size(self):
        """Return the number of stored items."""
        return len(self.memories)

class PretrainedMemorySystem:
    """
    Episodic memory backed by a pre-trained SentenceTransformer encoder.

    Every stored experience is embedded once with ``self.encoder`` and kept
    (moved to CPU) in a ``MemoryStore``. Queries are embedded the same way and
    ranked by cosine similarity against all stored embeddings.
    """

    def __init__(self, model_name='all-MiniLM-L6-v2'):
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        print(f"Loading pre-trained model: {model_name}...")
        # The default 'all-MiniLM-L6-v2' is a small, fast model
        # (~22M params, 384-dim embeddings).
        self.encoder = SentenceTransformer(model_name, device=self.device)
        print("✅ Model loaded!")

        self.store = MemoryStore()

    def store_experience(self, content, timestamp=None, category=None):
        """Embed ``content`` and store it as an Episode.

        Args:
            content: text of the experience.
            timestamp: defaults to ``datetime.now()`` when None.
            category: optional label used by ``retrieve_by_category``.

        Returns:
            The newly created Episode.
        """
        if timestamp is None:
            timestamp = datetime.now()

        embedding = self.encoder.encode(content, convert_to_tensor=True)

        episode = Episode(
            content=content,
            timestamp=timestamp,
            category=category,
            # Keep stored vectors on CPU so retrieval never mixes devices.
            embedding=embedding.cpu(),
        )

        self.store.add(episode)
        return episode

    def retrieve(self, query, top_k=5):
        """Return up to ``top_k`` memories most similar to ``query``.

        Returns:
            A list of ``(episode, cosine_similarity)`` tuples, best match
            first. It is empty when the store is empty or ``top_k <= 0``.
        """
        if top_k <= 0 or self.store.size() == 0:
            return []

        memories = self.store.get_all()
        query_emb = self.encoder.encode(query, convert_to_tensor=True).cpu()

        # Score all memories in one batched call: (1, D) vs (N, D) -> (N,).
        matrix = torch.stack([memory.embedding for memory in memories])
        scores = F.cosine_similarity(query_emb.unsqueeze(0), matrix, dim=1)

        k = min(top_k, len(memories))
        top_scores, top_indices = torch.topk(scores, k)  # sorted descending

        return [
            (memories[idx], float(score))
            for score, idx in zip(top_scores.tolist(), top_indices.tolist())
        ]

    def retrieve_recent(self, n=5):
        """Return the ``n`` most recently stored episodes, oldest first."""
        return self.store.get_recent(n)

    def retrieve_by_category(self, category):
        """Return every stored episode whose category equals ``category``."""
        return [m for m in self.store.get_all() if m.category == category]

# =============================================================================
# TESTING
# =============================================================================

def test_pretrained_memory():
    """
    End-to-end self-check of PretrainedMemorySystem using the real encoder.

    Stores a fixed set of categorized experiences, then checks:
      * each semantic query's top hit has the expected category, and
      * recency retrieval returns the last three stored experiences in order.

    Every check counts toward the total, whether it passes or fails.

    Returns:
        bool: True when at least 50% of the checks pass (the printed grade
        distinguishes 85%+/70%+/50%+), False otherwise.
    """
    print("\n" + "="*70)
    print("TESTING MEMORY V3 (PRE-TRAINED EMBEDDINGS)")
    print("="*70)

    print("\nInitializing memory system...")
    memory = PretrainedMemorySystem(model_name='all-MiniLM-L6-v2')

    print("\nStoring diverse experiences...")

    test_experiences = [
        # Food
        ("went to coffee shop for morning latte", "food"),
        ("bought groceries including milk bread and eggs", "food"),
        ("had delicious lunch at italian restaurant with pasta", "food"),
        ("cooked dinner with chicken and vegetables", "food"),

        # Work
        ("attended team meeting about project deadline", "work"),
        ("fixed critical bug in production code", "work"),
        ("reviewed code from colleague pull request", "work"),
        ("deployed new feature to cloud", "work"),

        # Reading
        ("read research paper about transformers architecture", "reading"),
        ("studied book chapter on machine learning algorithms", "reading"),
        ("read article about latest AI developments", "reading"),

        # Social
        ("called friend to catch up on life", "social"),
        ("went to party with colleagues from work", "social"),

        # Exercise
        ("went for morning run in the park", "exercise"),
        ("did yoga session at home", "exercise")
    ]

    base_time = datetime.now() - timedelta(days=10)

    for i, (content, category) in enumerate(test_experiences):
        # 12 hours apart, so insertion order matches chronological order.
        timestamp = base_time + timedelta(days=i/2)
        memory.store_experience(content, timestamp, category)
        print(f"  ✓ [{category:8s}] {content[:50]}")

    print(f"\n✅ Stored {memory.store.size()} memories")

    # Test semantic retrieval
    print("\n" + "="*70)
    print("TEST: SEMANTIC RETRIEVAL")
    print("="*70)

    # (query, expected category of top hit, keyword expected in that hit)
    test_queries = [
        ("breakfast coffee drink", "food", "coffee"),
        ("programming coding software", "work", "code"),
        ("scientific papers research", "reading", "paper"),
        ("restaurant eating meal", "food", "lunch"),
        ("project team collaboration", "work", "meeting"),
        ("friends social gathering", "social", "party"),
        ("running jogging fitness", "exercise", "run")
    ]

    passed = 0
    total = len(test_queries)

    for query, expected_cat, expected_keyword in test_queries:
        print(f"\nQuery: '{query}'")
        print(f"Expected category: {expected_cat} (keyword: '{expected_keyword}')")

        results = memory.retrieve(query, top_k=3)

        print("Top 3 matches:")
        for i, (ep, score) in enumerate(results, 1):
            marker = "✓" if ep.category == expected_cat else "✗"
            print(f"  {marker} {i}. [{ep.category:8s}] (score: {score:.3f}) {ep.content[:55]}")

        # Pass criterion: the top result has the expected category.
        if results and results[0][0].category == expected_cat:
            print(f"  ✅ CORRECT! Top match is {expected_cat}")
            passed += 1
        else:
            print(f"  ❌ WRONG - Top match is {results[0][0].category if results else 'none'}")

    # Test recency
    print("\n" + "="*70)
    print("TEST: RECENT MEMORIES")
    print("="*70)

    recent = memory.retrieve_recent(n=3)
    print("\nMost recent 3 memories:")
    for i, mem in enumerate(recent, 1):
        print(f"  {i}. {mem.content}")
        print(f"     {mem.timestamp.strftime('%Y-%m-%d %H:%M')}")

    # Count this check unconditionally so a failure lowers the score instead
    # of being silently left out of it.
    total += 1
    expected_recent = [content for content, _ in test_experiences[-3:]]
    if [mem.content for mem in recent] == expected_recent:
        print("  ✅ Recency retrieval working!")
        passed += 1
    else:
        print("  ❌ Recency retrieval returned unexpected memories")

    # Final results
    print("\n" + "="*70)
    print("FINAL RESULTS")
    print("="*70)

    accuracy = 100 * passed / total
    print(f"\nScore: {passed}/{total} ({accuracy:.1f}%)")

    if passed >= total * 0.85:
        print("\n✅ EXCELLENT - Semantic search WORKING!")
        return True
    elif passed >= total * 0.7:
        print("\n✅ GOOD - Strong semantic understanding!")
        return True
    elif passed >= total * 0.5:
        print("\n⚠️ DECENT - Partial success")
        return True
    else:
        print("\n❌ Still struggling")
        return False

def demo_memory_capabilities():
    """Print a short demo of natural-language queries against a fresh memory.

    Builds a new PretrainedMemorySystem with its default model, stores five
    sample experiences (timestamped "now"), and prints the two best matches
    for each of three example questions.
    """
    banner = "=" * 70
    print("\n" + banner)
    print("DEMONSTRATION: REAL-WORLD USAGE")
    print(banner)

    demo_memory = PretrainedMemorySystem()

    sample_experiences = (
        "Had coffee with Sarah to discuss the new marketing campaign",
        "Fixed the authentication bug that was blocking users",
        "Read paper on attention mechanisms in neural networks",
        "Bought ingredients for tonight's dinner party",
        "Went for a 5km run along the river",
    )

    print("\nStoring experiences...")
    for experience in sample_experiences:
        demo_memory.store_experience(experience)
        print(f"  ✓ {experience}")

    print("\n" + banner)
    print("Example queries:")

    sample_questions = (
        "what did I do related to food?",
        "any technical work I did?",
        "tell me about my exercise",
    )

    for question in sample_questions:
        print(f"\nQ: {question}")
        matches = demo_memory.retrieve(question, top_k=2)
        for rank, (episode, relevance) in enumerate(matches, start=1):
            print(f"  {rank}. {episode.content} (relevance: {relevance:.2f})")

def main():
    """
    Run the self-test, run the demo if the self-test passed, and print a summary.

    Returns:
        bool: True if ``test_pretrained_memory()`` passed, False otherwise.
    """
    success = test_pretrained_memory()

    # The demo constructs its own PretrainedMemorySystem (loading the model
    # again), so only run it when the self-test passed.
    if success:
        demo_memory_capabilities()

    print("\n" + "="*70)
    if success:
        print("✅ EPISODIC MEMORY V3: WORKING!")
        print("\nCapabilities:")
        print("  1. Store experiences with timestamps")
        print("  2. ✅ SEMANTIC RETRIEVAL (pre-trained)")
        print("  3. Retrieve by recency")
        print("  4. Retrieve by category")
        print("\n✅ CAPABILITY #7 COMPLETE")
        print("\n🎯 Using state-of-art Sentence-BERT embeddings!")
        # Only claim readiness when the self-test actually passed.
        print("\nMemory system ready for Eden integration!")
    else:
        print("❌ EPISODIC MEMORY V3: SELF-TEST FAILED")
        print("\nMemory system NOT ready for integration - see results above.")

    return success

# Script entry point: run the self-test and demo only when this file is
# executed directly, not when it is imported.
if __name__ == "__main__":
    main()
