import logging
import json
from dataclasses import dataclass, asdict
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any
import yaml  # Import yaml as well


@dataclass(frozen=True)
class Episode:
    """Immutable episodic memory structure.

    BUG FIX: the docstring promised immutability but the dataclass was
    mutable; ``frozen=True`` enforces it (callers already build decayed
    copies instead of mutating in place).
    """
    timestamp: str                   # ISO-8601 creation time
    content: str                     # what was experienced
    emotion: str                     # emotional valence label
    importance_score: float = 1.0    # retention weight, expected in [0, 1]

    def serialize(self) -> Dict[str, Any]:
        """Return a plain-dict form of this episode (JSON-friendly)."""
        return asdict(self)


class EpisodicMemoryBuffer:
    """
    Temporary memory buffer for episodic storage.

    Features:
    - Stores experiences with timestamps and emotional context
    - Automatically decays less important memories over time
    - Accessors cover both recent and important memories
    - Maintains running statistics on memory content
    """

    def __init__(self, max_size: int = 1000, default_importance: float = 0.5,
                 decay_rate: float = 0.05, min_importance: float = 0.1):
        """
        Args:
            max_size: Maximum number of episodes to keep
            default_importance: Default importance score for unweighted memories
            decay_rate: Rate at which old memories decrease in importance (per hour)
            min_importance: Minimum importance threshold for retention
        """
        self.max_size = max_size
        self.default_importance = default_importance
        self.decay_rate = decay_rate
        self.min_importance = min_importance

        self.buffer: List[Episode] = []
        # Monotonic counter so episode ids stay unique even after forgetting.
        self._episodes_added = 0
        self.stats: Dict[str, int] = {
            'total_stored': 0,
            'recent_count': 0,
            'important_count': 0,
            'forgotten': 0
        }
        # Backdated one hour so the first add() is allowed to run a cleanup pass.
        self.last_cleanup = datetime.now() - timedelta(hours=1)

        logging.info("Initializing episodic memory buffer with capacity %d", max_size)

    def add(self, content: str, emotion: str, importance_override: Optional[float] = None) -> Dict[str, Any]:
        """
        Add a new experience to the buffer.

        Args:
            content: What was experienced
            emotion: Emotional valence of the experience
            importance_override: Override default importance score;
                clamped into [min_importance, 1.0]

        Returns:
            Metadata about the added episode and current memory state
        """
        timestamp = datetime.now().isoformat()
        if importance_override is None:
            importance_score = self.default_importance
        else:
            importance_score = max(self.min_importance, min(1.0, importance_override))

        episode = Episode(timestamp=timestamp, content=content, emotion=emotion,
                          importance_score=importance_score)

        self.buffer.append(episode)
        self._episodes_added += 1
        self._maintain_memory_limits()
        self._cleanup_old_memories()

        self.stats['total_stored'] += 1
        self.stats['recent_count'] += 1

        if importance_score > self.min_importance:
            self.stats['important_count'] += 1

        return {
            'added': True,
            # BUG FIX: was len(self.buffer), which reuses ids once older
            # episodes are forgotten; the running counter is unique.
            'episode_id': self._episodes_added,
            'stats': self.get_stats(),
            'retention_rate': self._calculate_retention()
        }

    def _maintain_memory_limits(self):
        """Enforce max buffer size by forgetting the least important memories."""
        excess = len(self.buffer) - self.max_size
        if excess <= 0:
            return

        logging.info("Forgetting %d old memories to make space", excess)

        # Select the `excess` least important episodes and drop them by object
        # identity in a single pass. BUG FIX: the old code called
        # list.remove() per victim, which matches on *equality* (could evict
        # the wrong one of two identical episodes) and is O(n^2) overall.
        doomed = sorted(self.buffer, key=lambda e: e.importance_score)[:excess]
        doomed_ids = {id(e) for e in doomed}
        self.buffer = [e for e in self.buffer if id(e) not in doomed_ids]

        self.stats['forgotten'] += excess
        self.stats['total_stored'] = max(0, self.stats['total_stored'] - excess)

        # Recent count cannot exceed what is still stored after a large purge.
        if len(self.buffer) < self.max_size // 2:
            self.stats['recent_count'] = min(self.stats.get('recent_count', 0), len(self.buffer))

    def _cleanup_old_memories(self):
        """Decay importance of memories older than one hour (runs at most hourly)."""
        now = datetime.now()

        # Only cleanup hourly
        if (now - self.last_cleanup).total_seconds() < 3600:
            return

        logging.debug("Cleaning up old memories...")

        decay_threshold = now - timedelta(hours=1)

        for index, episode in enumerate(self.buffer):
            if datetime.fromisoformat(episode.timestamp) >= decay_threshold:
                continue

            # Decay importance by the configured rate, floored at min_importance.
            new_importance = max(self.min_importance,
                                 episode.importance_score - self.decay_rate)
            if new_importance == episode.importance_score:
                continue  # already at the floor; nothing to update

            # Episodes are treated as immutable, so store a decayed copy.
            # BUG FIX: the old code re-located each episode by timestamp,
            # which was O(n^2) and ambiguous when timestamps collide.
            self.buffer[index] = Episode(
                timestamp=episode.timestamp,
                content=episode.content,
                emotion=episode.emotion,
                importance_score=new_importance
            )

            # BUG FIX: only decrement the tally when the episode actually
            # crosses the importance threshold, not on every decay step.
            if episode.importance_score > self.min_importance >= new_importance:
                self.stats['important_count'] = max(0, self.stats['important_count'] - 1)

        self.last_cleanup = now

    def get_recent(self, n: int = 10) -> List[Episode]:
        """Get the n most recent episodes, newest first."""
        if n <= 0:
            return []  # BUG FIX: buffer[-0:] would return the whole buffer
        return list(reversed(self.buffer[-n:]))

    def get_important(self, min_importance: float = 0.5) -> List[Episode]:
        """Retrieve episodes at or above the given importance threshold."""
        return [e for e in self.buffer if e.importance_score >= min_importance]

    def get_stats(self) -> Dict[str, int]:
        """Get current memory statistics, computed from the live buffer."""
        now = datetime.now()
        recent_count = sum(1 for e in self.buffer
                           if now - datetime.fromisoformat(e.timestamp) < timedelta(minutes=5))

        return {
            'total_stored': len(self.buffer),
            'max_capacity': self.max_size,
            'recent_count': recent_count,
            # BUG FIX: computed live rather than read from the running tally,
            # which drifts as memories decay or are forgotten.
            'important_count': sum(1 for e in self.buffer
                                   if e.importance_score > self.min_importance),
            'forgotten': self.stats.get('forgotten', 0)
        }

    def _calculate_retention(self) -> float:
        """Retention rate = important memories / total stored (0.0 when empty)."""
        stats = self.get_stats()
        total_non_minimal = max(0, stats['total_stored'] - stats.get('forgotten', 0))
        if total_non_minimal == 0:
            return 0.0
        important_retained = min(stats['important_count'], total_non_minimal)
        return important_retained / total_non_minimal


def example_usage():
    """Demonstrate episodic memory buffer in action.

    Works with either buffer implementation in this module (the dataclass-
    based one and the dict-based one that shadows it at import time).
    """
    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')

    em = EpisodicMemoryBuffer(max_size=2000, default_importance=0.3, decay_rate=0.05)

    # Add some episodes
    print("Adding test memories...")
    for i in range(5):
        importance = 0.8 if i % 2 == 0 else None  # Every other is high importance
        result = em.add(f"Memory {i+1}: A significant experience", "joyful", importance_override=importance)
        print(json.dumps(result, indent=2))

    # Show statistics
    print("\nCurrent Statistics:")
    stats = em.get_stats()
    print(yaml.dump(stats, default_flow_style=False))

    # Test retention.
    # BUG FIX: the old demo called time.sleep(3600), blocking for a real hour.
    # Backdate the cleanup marker instead so the hourly guard lets a pass run.
    print("\nRetention after 1 hour (memories decaying):")
    em.last_cleanup = datetime.now() - timedelta(hours=2)
    # BUG FIX: the dict-based implementation names this method _cleanup_old,
    # and it is the one bound to the name at runtime; support both.
    cleanup = getattr(em, '_cleanup_old_memories', None) or getattr(em, '_cleanup_old')
    cleanup()
    stats = em.get_stats()
    print(yaml.dump(stats, default_flow_style=False))

    # Show most important
    important = em.get_important()
    print(f"\nMost Important ({len(important)}):")
    for ep in important:
        # BUG FIX: the runtime buffer stores dicts, so ep.content raised
        # AttributeError; handle both episode representations.
        if isinstance(ep, dict):
            print(f"- {ep['content']} [{ep['importance_score']:.2f}]")
        else:
            print(f"- {ep.content} [{ep.importance_score:.2f}]")


# Script entry point: run the demo only when executed directly, not on import.
if __name__ == "__main__":
    example_usage()  # Run demo when module is executed
# Task: Integrate the episodic_memory_buffer module into Eden's architecture.

# System Design:

# Module: Episodic Memory Buffer (memory_buffer)
# Purpose: Store and retrieve experiences with importance weighting
# Integrated Into: Eden Core Architecture
# Location: /Eden/CORE/episodic_memory_buffer.py

# Justification:
# Eden needs a temporary memory system to store ongoing experiences.
# Provides both recent and important retrieval methods.

# Component 1: init
# - Initializes episodic memory buffer with defaults
# - Logs initialization

# Component 2: add(experience, emotion, importance)
# - Stores experience with timestamp and emotional valence
# - Clamps any importance override into the [min_importance, 1.0] range
# - Maintains memory limits
# - Tracks statistics
# - Returns metadata about the added episode

# Component 3: get_recent(n)
# - Retrieves most recent experiences
# - Helps recall from recent interactions

# Component 4: add_important(experience, emotion)
# - Stores episodes at a fixed high importance (0.8) in the shared buffer
# - Prioritizes important memories in retrieval

# Component 5: get_stats()
# - Returns memory statistics for monitoring
# - Includes total stored, recent count, important count

# Justification:
# This design provides Eden with a rich memory system tailored to her architecture.
# It addresses the need for both temporal and semantic memory capabilities.



import json
from datetime import datetime, timedelta
from typing import List, Dict, Any
import logging


class EpisodicMemoryBuffer:
    """
    Temporal episodic memory buffer for Eden's core consciousness.

    Stores experiences as dicts with importance scores and provides
    recency- and importance-based retrieval methods.
    """

    def __init__(self, max_size: int = 1000, default_importance: float = 0.5,
                 decay_rate: float = 0.05, min_importance: float = 0.1):
        """
        Initialize episodic memory buffer.

        Args:
            max_size: Maximum number of episodes
            default_importance: Default importance score (0-1)
            decay_rate: Rate memories decay (per hour)
            min_importance: Minimum importance for retention
        """
        self.max_size = max_size
        self.default_importance = default_importance
        self.decay_rate = decay_rate
        self.min_importance = min_importance

        self.buffer: List[Dict] = []
        # Monotonic id counter so ids stay unique even after forgetting.
        self._next_id = 0
        self.stats = {
            'total_stored': 0,
            'recent_count': 0,
            'important_count': 0,
            'forgotten': 0
        }
        # Backdated one hour so the first add() is allowed to run a cleanup pass.
        self.last_cleanup = datetime.now() - timedelta(hours=1)

        logging.info("EpisodicMemoryBuffer initialized with max_size=%d", max_size)

    def add(self, content: str, emotion: str, importance_override: Optional[float] = None) -> Dict:
        """
        Add an experience to memory.

        Args:
            content: Experience description
            emotion: Emotional valence
            importance_override: Optional custom importance,
                clamped into [min_importance, 1.0]

        Returns:
            Metadata dict with the new episode id and current stats
        """
        timestamp = datetime.now().isoformat()
        if importance_override is None:
            importance = self.default_importance
        else:
            importance = max(self.min_importance, min(1.0, importance_override))

        # BUG FIX: was len(self.buffer) + 1, which reuses ids once episodes
        # are forgotten; the running counter is unique.
        self._next_id += 1
        episode = {
            'id': self._next_id,
            'timestamp': timestamp,
            'content': content,
            'emotion': emotion,
            'importance_score': importance
        }

        self.buffer.append(episode)
        self._maintain_limits()
        self._cleanup_old()

        self.stats['total_stored'] += 1
        self.stats['recent_count'] += 1

        if importance > self.min_importance:
            self.stats['important_count'] += 1

        return {
            'added': True,
            'episode_id': episode['id'],
            'stats': self.get_stats()
        }

    def get_recent(self, n: int = 10) -> List[Dict]:
        """Get the n most recent episodes (oldest first within the slice)."""
        if n <= 0:
            return []  # BUG FIX: buffer[-0:] would return the whole buffer
        return self.buffer[-n:]

    def add_important(self, content: str, emotion: str) -> Dict:
        """
        Add an episode at a fixed high importance (0.8).

        Args:
            content: Experience
            emotion: Valence

        Returns:
            Metadata dict from add()
        """
        return self.add(content, emotion, importance_override=0.8)

    def get_important(self, min_importance: float = 0.5) -> List[Dict]:
        """Retrieve episodes at or above the given importance threshold."""
        return [ep for ep in self.buffer if ep.get('importance_score', 0) >= min_importance]

    def get_stats(self) -> Dict:
        """Get current memory statistics, computed from the live buffer."""
        now = datetime.now()
        recent_count = sum(1 for e in self.buffer
                           if (now - datetime.fromisoformat(e['timestamp'])).total_seconds() < 300)

        return {
            'total_stored': len(self.buffer),
            'recent_count': recent_count,
            'important_count': sum(1 for e in self.buffer if e.get('importance_score', 0) >= self.min_importance),
            'forgotten': self.stats.get('forgotten', 0)
        }

    def _maintain_limits(self):
        """Enforce max buffer size by forgetting the least important episodes."""
        excess = len(self.buffer) - self.max_size
        if excess <= 0:
            return

        logging.info("Forgetting %d old episodes to make space", excess)

        # Select the `excess` least important episodes and remove them by
        # object identity in one pass. BUG FIX: the old code re-scanned the
        # buffer for each removal (`next(e ... if e in self.buffer)`), which
        # was O(n^2) and relied on dict equality.
        doomed = sorted(self.buffer, key=lambda x: x.get('importance_score', 0))[:excess]
        doomed_ids = {id(e) for e in doomed}
        self.buffer = [e for e in self.buffer if id(e) not in doomed_ids]

        for episode in doomed:
            self.stats['forgotten'] += 1
            if episode.get('importance_score', 0) >= self.min_importance:
                self.stats['important_count'] = max(0, self.stats['important_count'] - 1)

    def _cleanup_old(self):
        """Decay importance of episodes older than one hour (runs at most hourly)."""
        now = datetime.now()

        # Cleanup hourly
        if (now - self.last_cleanup).total_seconds() < 3600:
            return

        logging.debug("Cleaning up old memories...")

        cutoff = now - timedelta(hours=1)
        for episode in self.buffer:
            if datetime.fromisoformat(episode['timestamp']) >= cutoff:
                continue

            # Decay importance by the configured rate, floored at min_importance.
            current = episode.get('importance_score', 0)
            new_importance = max(self.min_importance, current - self.decay_rate)

            if new_importance < current:
                episode['importance_score'] = new_importance
                # BUG FIX: the old check `new_importance < self.min_importance`
                # was dead code (new_importance is floored at min_importance);
                # decrement only on an actual threshold crossing.
                if current > self.min_importance >= new_importance:
                    self.stats['important_count'] = max(0, self.stats['important_count'] - 1)

        self.last_cleanup = now


# Auto-generated Plugin wrapper
class Plugin:
    """Plugin wrapper for episodic_memory_buffer"""
    name = "episodic_memory_buffer"
    version = "1.0.0"
    
    def __init__(self):
        self.enabled = True
    
    def enhance(self, message):
        """Enhance message with this plugin's capability"""
        try:
            result = __import__('memory_buffer').EpisodicMemoryBuffer()
            return {"enhanced": True, "result": result}
        except Exception as e:
            return {"enhanced": False, "error": str(e)}
    
    def analyze(self, text):
        """Analyze text for this plugin's focus"""
        analysis = {
            'episodic_memory_buffer': True,
            'memory_management': True
        }
        return analysis
    
    def get_info(self):
        """Return plugin information"""
        return {
            "name": self.name,
            "version": self.version,
            "enabled": self.enabled
        }