#!/usr/bin/env python3
"""
Memory V2 - With Contrastive Learning
Properly trained semantic embeddings
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
from datetime import datetime, timedelta
import random

# =============================================================================
# MEMORY STORAGE
# =============================================================================

class Episode:
    """A single stored experience: raw text plus retrieval metadata."""

    def __init__(self, content, timestamp, category=None, embedding=None):
        # Raw text of the experience.
        self.content = content
        # When the experience happened.
        self.timestamp = timestamp
        # Optional label, used as the supervision signal during training.
        self.category = category
        # Optional precomputed embedding vector for similarity search.
        self.embedding = embedding
        # Per-process identifier derived from object identity.
        self.id = id(self)

class MemoryStore:
    """Bounded FIFO collection of Episode objects (oldest evicted first)."""

    def __init__(self, max_size=10000):
        # Backing list; newest episodes are appended at the end.
        self.memories = []
        self.max_size = max_size

    def add(self, episode):
        """Append an episode, dropping the oldest one if over capacity."""
        self.memories.append(episode)
        if len(self.memories) > self.max_size:
            del self.memories[0]

    def get_all(self):
        """Return the full backing list (oldest first)."""
        return self.memories

    def get_recent(self, n=10):
        """Return the n most recently added episodes."""
        return self.memories[-n:]

    def get_by_category(self, category):
        """Return every stored episode whose category matches exactly."""
        matches = []
        for episode in self.memories:
            if episode.category == category:
                matches.append(episode)
        return matches

    def size(self):
        """Number of episodes currently stored."""
        return len(self.memories)

# =============================================================================
# IMPROVED ENCODER WITH CONTRASTIVE LEARNING
# =============================================================================

class ContrastiveMemoryEncoder(nn.Module):
    """
    Text encoder trained with contrastive learning so that experiences
    from the same category land close together in embedding space.
    """

    def __init__(self, vocab_size=10000, embed_dim=256, hidden_dim=256):
        super().__init__()

        # Token embedding table; index 0 is reserved for padding.
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)

        # Two-layer bidirectional LSTM over the embedded token sequence.
        self.encoder = nn.LSTM(
            embed_dim, hidden_dim,
            num_layers=2,
            batch_first=True,
            bidirectional=True,
            dropout=0.2
        )

        # SimCLR-style projection head mapping the concatenated
        # forward/backward states back down to hidden_dim.
        self.projection = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, hidden_dim)
        )

        # When True, outputs are unit-length so dot product == cosine.
        self.normalize = True

    def forward(self, text_indices):
        """
        Encode a batch of token-index sequences.

        Args:
            text_indices: [batch, seq_len] integer tensor
        Returns:
            [batch, hidden_dim] embeddings, L2-normalized when enabled
        """
        tokens = self.embedding(text_indices)
        _, (final_hidden, _) = self.encoder(tokens)
        # final_hidden is [num_layers * 2, batch, hidden]; the last two
        # entries are the top layer's forward and backward final states.
        summary = torch.cat((final_hidden[-2], final_hidden[-1]), dim=-1)
        result = self.projection(summary)
        return F.normalize(result, p=2, dim=-1) if self.normalize else result

# =============================================================================
# CONTRASTIVE LOSS
# =============================================================================

class ContrastiveLoss(nn.Module):
    """
    Supervised NT-Xent (Normalized Temperature-scaled Cross Entropy) loss,
    as used in SimCLR-style training: for each anchor, same-label samples
    are pulled together and all other samples are pushed apart.
    """

    def __init__(self, temperature=0.5):
        super().__init__()
        # Lower temperature sharpens the similarity distribution.
        self.temperature = temperature

    def forward(self, embeddings, labels):
        """
        Args:
            embeddings: [batch, hidden_dim], assumed L2-normalized by the
                encoder (the loss is still well-defined without it)
            labels: [batch] integer category labels
        Returns:
            Scalar loss averaged over anchors with at least one positive
            partner; 0.0 when no anchor has any positive.
        """
        batch_size = embeddings.shape[0]
        device = embeddings.device

        # Pairwise temperature-scaled similarity matrix.
        similarity = torch.matmul(embeddings, embeddings.T) / self.temperature

        # Positives: same label, excluding each sample's pairing with itself.
        eye = torch.eye(batch_size, dtype=torch.bool, device=device)
        same_label = labels.unsqueeze(1) == labels.unsqueeze(0)
        mask_positive = same_label & ~eye

        has_positive = mask_positive.any(dim=1)
        if not has_positive.any():
            return torch.tensor(0.0, device=device)

        neg_inf = float('-inf')
        # BUG FIX: the original computed log(exp(...).sum() / exp(...).sum())
        # despite its comment claiming log-sum-exp — that overflows for large
        # similarities. torch.logsumexp is the numerically stable equivalent;
        # masked entries contribute exp(-inf) == 0 to the sums.
        pos_lse = torch.logsumexp(similarity.masked_fill(~mask_positive, neg_inf), dim=1)
        all_lse = torch.logsumexp(similarity.masked_fill(eye, neg_inf), dim=1)

        # Per-anchor InfoNCE term: -log(sum_pos exp(sim) / sum_{all != i} exp(sim)).
        per_anchor = all_lse - pos_lse

        # Average only over anchors that actually have positives
        # (matches the original loop's n_positives normalization).
        return per_anchor[has_positive].mean()

# =============================================================================
# IMPROVED MEMORY SYSTEM
# =============================================================================

class ImprovedMemorySystem:
    """
    Episodic memory backed by a contrastively trained text encoder.

    Experiences are stored with precomputed embeddings and retrieved by
    cosine similarity against an encoded query.
    """

    def __init__(self, vocab_size=10000, embed_dim=256, hidden_dim=256):
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.store = MemoryStore()

        # Encoder producing L2-normalized embeddings.
        self.encoder = ContrastiveMemoryEncoder(
            vocab_size, embed_dim, hidden_dim
        ).to(self.device)

        # BUG FIX: remember the configured capacity. The original
        # _text_to_indices compared next_id against a hard-coded 10000,
        # so a smaller vocab_size could hand out indices past the end of
        # the embedding table and crash at encode time.
        self.vocab_size = vocab_size

        # Word -> index vocabulary, grown on the fly as text is seen.
        self.vocab = {"<PAD>": 0, "<UNK>": 1}
        self.next_id = 2

    def _text_to_indices(self, text, max_len=50):
        """Tokenize text (lowercased whitespace split) into a fixed-length
        index tensor, padding with 0 or truncating to max_len."""
        words = text.lower().split()
        indices = []

        for word in words:
            if word not in self.vocab:
                if self.next_id < self.vocab_size:
                    self.vocab[word] = self.next_id
                    self.next_id += 1
                else:
                    # Vocabulary is full: map unseen words to <UNK>.
                    indices.append(1)
                    continue
            indices.append(self.vocab[word])

        if len(indices) < max_len:
            indices += [0] * (max_len - len(indices))
        else:
            indices = indices[:max_len]

        return torch.tensor(indices)

    def _encode_text(self, text):
        """Encode text to a [hidden_dim] embedding (no gradients, eval mode)."""
        indices = self._text_to_indices(text).unsqueeze(0).to(self.device)
        with torch.no_grad():
            self.encoder.eval()
            embedding = self.encoder(indices)
        return embedding.squeeze(0)

    def store_experience(self, content, timestamp=None, category=None):
        """Embed and store an experience; returns the created Episode.

        Args:
            content: experience text.
            timestamp: datetime of the experience (defaults to now).
            category: optional label kept for evaluation/training.
        """
        if timestamp is None:
            timestamp = datetime.now()

        embedding = self._encode_text(content)

        episode = Episode(
            content=content,
            timestamp=timestamp,
            category=category,
            # Stored on CPU so the memory bank does not pin GPU memory.
            embedding=embedding.cpu()
        )

        self.store.add(episode)
        return episode

    def retrieve(self, query, top_k=5):
        """Return up to top_k (Episode, score) pairs most similar to query,
        best match first. Scores are dot products, which equal cosine
        similarity because embeddings are L2-normalized."""
        if self.store.size() == 0:
            return []

        query_emb = self._encode_text(query).cpu()

        scores = []
        for memory in self.store.get_all():
            sim = (query_emb * memory.embedding).sum().item()
            scores.append(sim)

        scores = np.array(scores)
        # argsort ascending; take the last top_k and reverse for best-first.
        top_indices = np.argsort(scores)[-top_k:][::-1]

        results = []
        memories = self.store.get_all()
        for idx in top_indices:
            results.append((memories[idx], float(scores[idx])))

        return results

    def retrieve_recent(self, n=5):
        """Return the n most recently stored episodes (temporal retrieval)."""
        return self.store.get_recent(n)

# =============================================================================
# TRAINING DATA
# =============================================================================

def generate_categorized_data():
    """Build a balanced labeled dataset for contrastive training.

    Each of the 5 categories contributes 10 template sentences, each
    repeated 3 times with a random timestamp inside the last 30 days.
    Returns a list of (text, timestamp, category) tuples.
    """

    categories = {
        'food': [
            'went to restaurant for dinner',
            'cooked pasta with tomato sauce',
            'had pizza with friends',
            'bought groceries milk and eggs',
            'ate breakfast cereal and coffee',
            'ordered takeout chinese food',
            'made sandwich for lunch',
            'baked cookies in oven',
            'went to coffee shop for latte',
            'had lunch at italian place'
        ],
        'work': [
            'attended team meeting about project',
            'fixed bug in the code',
            'reviewed pull request from colleague',
            'wrote documentation for API',
            'had video call with client',
            'worked on machine learning model',
            'debugged production issue',
            'presented results to stakeholders',
            'pair programmed with teammate',
            'deployed new feature to production'
        ],
        'reading': [
            'read paper about neural networks',
            'studied book on algorithms',
            'read article about AI trends',
            'learned about transformers from paper',
            'read blog post on optimization',
            'studied research on reinforcement learning',
            'read documentation for library',
            'reviewed literature on computer vision',
            'read tutorial on pytorch',
            'studied textbook chapter on ML'
        ],
        'social': [
            'called friend to catch up',
            'went to party with colleagues',
            'met up with old classmate',
            'had coffee with neighbor',
            'video chatted with family',
            'went to concert with friends',
            'attended networking event',
            'played games online with buddies',
            'went bowling with group',
            'celebrated birthday with friends'
        ],
        'exercise': [
            'went for run in park',
            'did yoga session at home',
            'went to gym for workout',
            'played basketball with friends',
            'went swimming at pool',
            'did weightlifting routine',
            'went hiking on trail',
            'rode bicycle around city',
            'did cardio on treadmill',
            'played tennis at court'
        ]
    }

    window_start = datetime.now() - timedelta(days=30)

    def _random_moment():
        # Uniform day/hour offset inside the 30-day window.
        return window_start + timedelta(
            days=random.randint(0, 30),
            hours=random.randint(0, 23)
        )

    # Balanced: every sentence appears 3 times with independent timestamps.
    return [
        (sentence, _random_moment(), label)
        for label, sentences in categories.items()
        for sentence in sentences
        for _ in range(3)
    ]

# =============================================================================
# TRAINING
# =============================================================================

def train_contrastive_memory(epochs=100, checkpoint_path='memory_v2.pth'):
    """
    Train the memory encoder with supervised contrastive learning on
    synthetic categorized experiences.

    Args:
        epochs: number of passes over the generated data.
        checkpoint_path: where the best (lowest-loss) encoder weights and
            vocabulary are saved. Default matches test_improved_memory.
    Returns:
        The ImprovedMemorySystem whose encoder was trained.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Device: {device}\n")

    print("Initializing memory system...")
    memory = ImprovedMemorySystem(vocab_size=10000, embed_dim=256, hidden_dim=256)

    print("Generating categorized training data...")
    data = generate_categorized_data()
    random.shuffle(data)

    print(f"Generated {len(data)} experiences across 5 categories\n")

    # Map category names to integer labels for the contrastive loss.
    category_to_id = {'food': 0, 'work': 1, 'reading': 2, 'social': 3, 'exercise': 4}

    batch_size = 32
    loss_fn = ContrastiveLoss(temperature=0.5)
    optimizer = torch.optim.AdamW(memory.encoder.parameters(), lr=0.001, weight_decay=0.01)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    print(f"Training encoder with contrastive learning for {epochs} epochs...\n")

    best_loss = float('inf')

    for epoch in range(epochs):
        memory.encoder.train()
        epoch_loss = 0
        n_batches = 0

        # Reshuffle each epoch so batches mix categories differently.
        random.shuffle(data)

        # BUG FIX: the original upper bound was len(data) - batch_size,
        # which silently dropped the final full batch every epoch;
        # +1 keeps drop-last semantics while covering every complete batch.
        for i in tqdm(range(0, len(data) - batch_size + 1, batch_size), desc=f"Epoch {epoch+1}"):
            batch = data[i:i+batch_size]

            texts = [item[0] for item in batch]
            categories = [category_to_id[item[2]] for item in batch]

            # Tokenize on the fly (this also grows the vocabulary).
            indices = torch.stack([memory._text_to_indices(t) for t in texts]).to(device)
            labels = torch.tensor(categories).to(device)

            optimizer.zero_grad()
            embeddings = memory.encoder(indices)
            loss = loss_fn(embeddings, labels)

            loss.backward()
            # Clip gradients to stabilize LSTM training.
            torch.nn.utils.clip_grad_norm_(memory.encoder.parameters(), 1.0)
            optimizer.step()

            epoch_loss += loss.item()
            n_batches += 1

        scheduler.step()

        avg_loss = epoch_loss / max(n_batches, 1)

        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch+1}: Loss: {avg_loss:.4f}")

        # Keep the best checkpoint: encoder weights plus the vocabulary
        # (both are needed to reproduce embeddings at load time).
        if avg_loss < best_loss:
            best_loss = avg_loss
            torch.save({
                'encoder': memory.encoder.state_dict(),
                'vocab': memory.vocab
            }, checkpoint_path)

    print(f"\n✅ Training complete - Best loss: {best_loss:.4f}")

    return memory

# =============================================================================
# TESTING
# =============================================================================

def test_improved_memory():
    """Smoke-test the trained encoder: store labeled experiences, then
    check that semantic retrieval surfaces the expected category."""
    banner = "=" * 70
    print("\n" + banner)
    print("TESTING IMPROVED MEMORY")
    print(banner)

    memory = ImprovedMemorySystem(vocab_size=10000, embed_dim=256, hidden_dim=256)

    # Restore the contrastively trained encoder and its vocabulary.
    ckpt = torch.load('memory_v2.pth')
    memory.encoder.load_state_dict(ckpt['encoder'])
    memory.vocab = ckpt['vocab']

    print("\nStoring test experiences...")

    experiences = [
        ("went to coffee shop morning latte", datetime.now() - timedelta(days=1), "food"),
        ("team meeting about AI project deadline", datetime.now() - timedelta(days=2), "work"),
        ("bought groceries including milk bread eggs", datetime.now() - timedelta(days=3), "food"),
        ("read research paper about transformers", datetime.now() - timedelta(days=5), "reading"),
        ("had lunch at italian restaurant", datetime.now() - timedelta(days=7), "food"),
        ("fixed critical bug in production code", datetime.now() - timedelta(days=8), "work"),
        ("called friend catch up on life", datetime.now() - timedelta(days=9), "social"),
    ]

    for text, when, label in experiences:
        memory.store_experience(text, when, label)
        print(f"  ✓ [{label}] {text}")

    print(f"\nTotal: {memory.store.size()} memories")

    # Semantic retrieval checks: (query, expected category, fallback word).
    print("\n" + banner)
    print("TEST: Semantic Retrieval")
    print(banner)

    queries = [
        ("coffee breakfast", "food", "coffee"),
        ("coding programming", "work", "bug"),
        ("papers research", "reading", "paper"),
        ("dinner food eating", "food", "lunch"),
        ("meeting work", "work", "meeting")
    ]

    passed = 0

    for query, want_cat, want_word in queries:
        print(f"\nQuery: '{query}' (expecting {want_cat})")
        hits = memory.retrieve(query, top_k=3)

        print("Top 3:")
        for rank, (ep, score) in enumerate(hits, 1):
            print(f"  {rank}. [{ep.category}] (score: {score:.3f}) {ep.content[:50]}")

        # Pass if the best hit has the right category, or at least
        # contains the expected keyword.
        top = hits[0][0] if hits else None
        if top is not None and top.category == want_cat:
            print("  ✅ Correct category!")
            passed += 1
        elif top is not None and want_word in top.content:
            print("  ✅ Relevant content!")
            passed += 1
        else:
            print("  ❌ Wrong")

    print("\n" + banner)
    print("RESULTS")
    print(banner)

    print(f"\nTests: {passed}/5")

    if passed >= 4:
        print("\n✅ EXCELLENT - Semantic search working!")
        return True
    if passed >= 3:
        print("\n✅ GOOD - Mostly working!")
        return True
    print("\n⚠️ Needs more work")
    return False

def main():
    """Train the contrastive encoder, then run the retrieval test suite."""
    # Training writes the best checkpoint to memory_v2.pth, which
    # test_improved_memory reloads; the returned system itself is unused.
    train_contrastive_memory(epochs=100)
    success = test_improved_memory()

    print("\n" + "="*70)
    if success:
        print("✅ EPISODIC MEMORY V2: WORKING")
        print("\nCapabilities:")
        print("  1. Store experiences with categories")
        print("  2. Semantic retrieval (contrastive learning)")
        print("  3. Temporal retrieval")
        print("\n✅ CAPABILITY #7 COMPLETE")

if __name__ == "__main__":
    main()
