"""
Eden Core - State Management & Persistence System
Addresses: State consistency, crash recovery, data persistence, rollback capabilities
"""

import copy
import hashlib
import json
import os
import pickle
import shutil
from dataclasses import dataclass, asdict
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional


@dataclass
class StateSnapshot:
    """A point-in-time copy of state plus a checksum over that copy.

    The class itself does not prevent field mutation; ``checksum`` is what
    lets a reader detect that ``state_data`` changed after it was recorded.
    """
    snapshot_id: str              # identifier of this snapshot
    timestamp: str                # creation time as a string
    state_data: Dict[str, Any]    # the captured state
    metadata: Dict[str, Any]      # free-form information about the snapshot
    checksum: str                 # hex SHA-256 of the key-sorted JSON of state_data

    def verify_integrity(self) -> bool:
        """Return True when ``checksum`` matches the current ``state_data``.

        The digest is SHA-256 over ``json.dumps(state_data, sort_keys=True)``,
        so dictionary key order does not affect the result.
        """
        canonical = json.dumps(self.state_data, sort_keys=True)
        digest = hashlib.sha256(canonical.encode()).hexdigest()
        return self.checksum == digest


class StateManager:
    """Comprehensive state management with versioning and recovery.

    Holds the live state in ``current_state`` and an oldest-first list of
    ``StateSnapshot`` objects in ``state_history`` (bounded by
    ``max_history``). Every snapshot is also written to
    ``<state_dir>/snapshot_<id>.json``. On construction, the most recent valid
    snapshot files are reloaded and the newest one becomes the current state.

    State values must be JSON-serializable, because snapshots are checksummed
    and persisted with ``json``.
    """

    def __init__(self, state_dir: str = "./eden_states", max_history: int = 100):
        """
        Args:
            state_dir: Directory holding snapshot files; created if missing.
            max_history: Maximum number of snapshots kept in memory (and
                reloaded from disk at startup). Must be at least 1.

        Raises:
            ValueError: If ``max_history`` is less than 1.
        """
        if max_history < 1:
            raise ValueError(f"max_history must be >= 1, got {max_history}")
        self.state_dir = state_dir
        self.current_state: Dict[str, Any] = {}
        self.state_history: List[StateSnapshot] = []
        self.max_history = max_history

        # Create state directory if it doesn't exist
        os.makedirs(state_dir, exist_ok=True)

        # Try to load the most recent saved state
        self._load_latest_state()

    def update_state(self, key: str, value: Any, create_snapshot: bool = True):
        """Set ``current_state[key] = value``.

        Args:
            key: State key.
            value: New value (must be JSON-serializable if a snapshot is taken).
            create_snapshot: When True, record a snapshot whose metadata holds
                the key and a truncated string form of the previous value.
        """
        old_value = self.current_state.get(key)
        self.current_state[key] = value

        if create_snapshot:
            self.create_snapshot(metadata={
                'operation': 'update',
                'key': key,
                'old_value': str(old_value)[:100]  # Truncate for logging
            })

    def get_state(self, key: str, default: Any = None) -> Any:
        """Return ``current_state[key]``, or ``default`` if the key is absent."""
        return self.current_state.get(key, default)

    def create_snapshot(self, metadata: Optional[Dict] = None) -> StateSnapshot:
        """Capture, persist and record a snapshot of the current state.

        Args:
            metadata: Free-form information stored with the snapshot.

        Returns:
            The new snapshot (also appended to ``state_history``).

        Raises:
            TypeError: If the current state or metadata is not JSON-serializable.
            OSError: If the snapshot file cannot be written.
        """
        timestamp = datetime.now().isoformat()
        # Random salt: the clock may not advance between rapid calls, and
        # len(state_history) stays constant once pruning kicks in. Timestamp
        # plus length alone could therefore repeat and overwrite an earlier
        # snapshot file.
        seed = f"{timestamp}{len(self.state_history)}".encode() + os.urandom(8)
        snapshot_id = hashlib.sha256(seed).hexdigest()[:16]

        # Create checksum of state data
        state_str = json.dumps(self.current_state, sort_keys=True)
        checksum = hashlib.sha256(state_str.encode()).hexdigest()

        snapshot = StateSnapshot(
            snapshot_id=snapshot_id,
            timestamp=timestamp,
            # Deep copy: a shallow copy would share nested lists/dicts with
            # current_state. Later in-place edits would then silently alter
            # this snapshot and invalidate its checksum.
            state_data=copy.deepcopy(self.current_state),
            metadata=metadata or {},
            checksum=checksum
        )

        # Persist first, so history never references a snapshot that failed
        # to reach disk.
        self._save_snapshot(snapshot)
        self.state_history.append(snapshot)

        # Prune the oldest in-memory snapshots beyond max_history
        excess = len(self.state_history) - self.max_history
        if excess > 0:
            del self.state_history[:excess]

        return snapshot

    def rollback_to_snapshot(self, snapshot_id: str) -> bool:
        """Restore the state captured by the in-memory snapshot ``snapshot_id``.

        A new snapshot tagged ``operation='rollback'`` is recorded afterwards.

        Returns:
            True if the snapshot was found and restored, False if it is not in
            ``state_history``.

        Raises:
            ValueError: If the snapshot fails its integrity check.
        """
        for snapshot in reversed(self.state_history):
            if snapshot.snapshot_id != snapshot_id:
                continue
            if not snapshot.verify_integrity():
                raise ValueError(f"Snapshot {snapshot_id} failed integrity check")

            # Deep copy so that later edits cannot reach back into the snapshot.
            self.current_state = copy.deepcopy(snapshot.state_data)
            self.create_snapshot(metadata={
                'operation': 'rollback',
                'target_snapshot': snapshot_id
            })
            return True

        return False

    def rollback_steps(self, steps: int = 1) -> bool:
        """Roll back to the snapshot ``steps`` entries before the newest one.

        ``steps=0`` restores the newest snapshot, discarding changes made with
        ``create_snapshot=False`` since then.

        Returns:
            False if ``steps`` is negative or reaches past the oldest snapshot,
            otherwise the result of ``rollback_to_snapshot``.
        """
        # Negative steps would index from the front of the history.
        if steps < 0 or steps >= len(self.state_history):
            return False

        target_snapshot = self.state_history[-(steps + 1)]
        return self.rollback_to_snapshot(target_snapshot.snapshot_id)

    def _save_snapshot(self, snapshot: StateSnapshot):
        """Write ``snapshot`` to ``<state_dir>/snapshot_<id>.json`` atomically.

        The data goes to a temp file that is fsync'ed and then renamed over
        the target. A crash mid-write therefore never leaves a truncated
        snapshot behind.
        """
        filepath = os.path.join(self.state_dir, f"snapshot_{snapshot.snapshot_id}.json")
        # The dot prefix keeps the temp file out of _load_latest_state's
        # 'snapshot_' filter.
        tmp_path = os.path.join(self.state_dir, f".snapshot_{snapshot.snapshot_id}.json.tmp")

        try:
            with open(tmp_path, 'w') as f:
                json.dump({
                    'snapshot_id': snapshot.snapshot_id,
                    'timestamp': snapshot.timestamp,
                    'state_data': snapshot.state_data,
                    'metadata': snapshot.metadata,
                    'checksum': snapshot.checksum
                }, f, indent=2)
                f.flush()
                os.fsync(f.fileno())
            os.replace(tmp_path, filepath)
        except BaseException:
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            raise

    def _load_latest_state(self):
        """Reload recent snapshots from disk, newest first by file mtime.

        Unreadable or checksum-failing files are reported and skipped, so one
        corrupt file does not discard all saved state. Up to ``max_history``
        valid snapshots are loaded into ``state_history`` (oldest first). The
        newest valid one becomes ``current_state``.
        """
        if not os.path.isdir(self.state_dir):
            return

        paths = [
            os.path.join(self.state_dir, filename)
            for filename in os.listdir(self.state_dir)
            if filename.startswith('snapshot_') and filename.endswith('.json')
        ]
        paths.sort(key=os.path.getmtime, reverse=True)

        loaded: List[StateSnapshot] = []
        for path in paths:
            if len(loaded) >= self.max_history:
                break
            try:
                with open(path, 'r') as f:
                    snapshot = StateSnapshot(**json.load(f))
            except (OSError, ValueError, TypeError) as e:
                # ValueError covers JSON decode errors; TypeError covers
                # files whose fields don't match StateSnapshot.
                print(f"Error loading state: {str(e)}")
                continue
            if not snapshot.verify_integrity():
                print(f"Warning: Snapshot {snapshot.snapshot_id} failed integrity check")
                continue
            loaded.append(snapshot)

        if not loaded:
            return

        loaded.reverse()  # state_history is kept oldest-first
        self.state_history = loaded
        self.current_state = copy.deepcopy(loaded[-1].state_data)
        print(f"Loaded state from {loaded[-1].timestamp}")

    def get_state_history(self) -> List[Dict]:
        """Return id, timestamp and metadata of every in-memory snapshot, oldest first."""
        return [{
            'snapshot_id': s.snapshot_id,
            'timestamp': s.timestamp,
            'metadata': s.metadata
        } for s in self.state_history]


class PersistenceManager:
    """Manage persistence of critical data with backup and recovery.

    Item ``name`` lives in ``<data_dir>/<name>.<format>``. Before an existing
    file is overwritten, it is copied to
    ``<data_dir>/backups/<name>_<YYYYmmdd_HHMMSS>.<format>``. The 10 newest
    backups per item are kept, and they serve as a fallback when the main
    file cannot be read.

    Supported formats are 'json' and 'pickle'.
    SECURITY: unpickling can execute arbitrary code. Only load pickle files
    written by a trusted process.
    """

    _FORMATS = ('json', 'pickle')

    def __init__(self, data_dir: str = "./eden_data"):
        """Create ``data_dir`` and its ``backups`` subdirectory if missing."""
        self.data_dir = data_dir
        self.backup_dir = os.path.join(data_dir, "backups")

        os.makedirs(data_dir, exist_ok=True)
        os.makedirs(self.backup_dir, exist_ok=True)

    def save_data(self, name: str, data: Any, format: str = 'json'):
        """
        Save data with automatic backup of previous version

        Args:
            name: Data identifier
            data: Data to save
            format: 'json' or 'pickle'

        Raises:
            ValueError: If ``format`` is unsupported (checked before any file
                is touched).
            TypeError: (among others) if ``data`` cannot be serialized. The
                existing file is then left unchanged.
        """
        if format not in self._FORMATS:
            raise ValueError(f"Unsupported format: {format}")

        filepath = os.path.join(self.data_dir, f"{name}.{format}")
        tmp_path = os.path.join(self.data_dir, f".{name}.{format}.tmp")

        try:
            # Serialize to a temp file first. If dumping fails partway, the
            # current file is untouched and no spurious backup is made.
            self._write_file(tmp_path, data, format)

            if os.path.exists(filepath):
                backup_name = f"{name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.{format}"
                # NOTE: two saves within the same second reuse this name, so
                # the earlier backup from that second is overwritten.
                shutil.copy2(filepath, os.path.join(self.backup_dir, backup_name))

                # Keep only last 10 backups per file
                self._prune_backups(name, format)

            # os.replace swaps the file in a single rename (atomic on POSIX).
            os.replace(tmp_path, filepath)
        except BaseException:
            self._discard(tmp_path)
            raise

    def load_data(self, name: str, format: str = 'json', default: Any = None) -> Any:
        """Load data with automatic fallback to backup if corrupted.

        Returns the parsed main file. If it is missing or unreadable, returns
        the newest readable backup. Returns ``default`` if neither exists or
        ``format`` is unsupported.
        """
        if format not in self._FORMATS:
            return default

        filepath = os.path.join(self.data_dir, f"{name}.{format}")

        # Try to load main file
        try:
            return self._read_file(filepath, format)
        except FileNotFoundError:
            pass  # never saved, or the main file was lost: still try backups
        except Exception as e:
            print(f"Error loading {name}: {str(e)}")

        backup = self._load_latest_backup(name, format)
        if backup is not None:
            print(f"Loaded from backup for {name}")
            return backup

        return default

    def _load_latest_backup(self, name: str, format: str) -> Optional[Any]:
        """Return the newest readable backup of ``name``, or None.

        Corrupt backups are reported and skipped in favor of older ones.
        """
        for filename in reversed(self._backup_files(name, format)):
            try:
                return self._read_file(os.path.join(self.backup_dir, filename), format)
            except Exception as e:
                print(f"Error loading backup: {str(e)}")
        return None

    def _prune_backups(self, name: str, format: str, keep: int = 10):
        """Keep only the most recent ``keep`` backups of ``name``."""
        backups = self._backup_files(name, format)
        for old_backup in backups[:max(len(backups) - keep, 0)]:
            os.remove(os.path.join(self.backup_dir, old_backup))

    def _backup_files(self, name: str, format: str) -> List[str]:
        """List backup filenames of exactly ``name``, oldest first.

        Only ``<name>_<YYYYmmdd_HHMMSS>.<format>`` matches. A plain prefix test
        would also pick up backups of e.g. ``<name>_v2``.
        """
        prefix = f"{name}_"
        suffix = f".{format}"
        matches = []
        for filename in os.listdir(self.backup_dir):
            if not (filename.startswith(prefix) and filename.endswith(suffix)):
                continue
            stamp = filename[len(prefix):len(filename) - len(suffix)]
            if len(stamp) == 15 and stamp[8] == '_' and (stamp[:8] + stamp[9:]).isdigit():
                matches.append(filename)
        # Fixed-width timestamps make lexicographic order chronological.
        return sorted(matches)

    @staticmethod
    def _write_file(path: str, data: Any, format: str):
        """Serialize ``data`` to ``path`` in ``format`` and fsync it."""
        with open(path, 'w' if format == 'json' else 'wb') as f:
            if format == 'json':
                json.dump(data, f, indent=2)
            else:
                pickle.dump(data, f)
            f.flush()
            os.fsync(f.fileno())

    @staticmethod
    def _read_file(path: str, format: str) -> Any:
        """Deserialize ``path`` according to ``format``."""
        if format == 'json':
            with open(path, 'r') as f:
                return json.load(f)
        if format == 'pickle':
            # SECURITY: never point this at untrusted files.
            with open(path, 'rb') as f:
                return pickle.load(f)
        raise ValueError(f"Unsupported format: {format}")

    @staticmethod
    def _discard(path: str):
        """Best-effort removal of a temp file."""
        try:
            os.remove(path)
        except OSError:
            pass


class CrashRecoveryManager:
    """Handle crash recovery and graceful shutdown.

    Uses the 'last_shutdown' record stored through the PersistenceManager.
    ``graceful_shutdown`` writes it with type 'graceful'.
    ``check_crash_recovery`` treats any other non-empty record as an unclean
    exit.
    """

    def __init__(self, state_manager: StateManager, persistence_manager: PersistenceManager):
        self.state_manager = state_manager
        self.persistence_manager = persistence_manager
        # One dict per recovery attempt: timestamp, type, success, plus
        # snapshot_id and/or error when known.
        self.recovery_log: List[Dict] = []
        self.shutdown_hooks: List[Callable] = []

    def register_shutdown_hook(self, hook: Callable):
        """Register a zero-argument function to call during graceful shutdown."""
        self.shutdown_hooks.append(hook)

    def graceful_shutdown(self):
        """Perform graceful shutdown with state preservation.

        Steps: take a final snapshot, run each registered hook (failures are
        printed, not raised), then record a 'graceful' last_shutdown entry
        pointing at the final snapshot.
        """
        print("Initiating graceful shutdown...")

        # Create final state snapshot
        snapshot = self.state_manager.create_snapshot(metadata={
            'operation': 'shutdown',
            'reason': 'graceful'
        })

        # Execute shutdown hooks; one failing hook must not block the rest.
        for hook in self.shutdown_hooks:
            try:
                hook()
            except Exception as e:
                print(f"Shutdown hook failed: {str(e)}")

        # Save recovery information
        self.persistence_manager.save_data('last_shutdown', {
            'timestamp': datetime.now().isoformat(),
            'snapshot_id': snapshot.snapshot_id,
            'type': 'graceful'
        })

        print("Shutdown complete")

    def check_crash_recovery(self) -> bool:
        """Return True if the last_shutdown record indicates an unclean exit.

        A missing, empty or malformed (non-dict) record is not treated as
        evidence of a crash.
        """
        last_shutdown = self.persistence_manager.load_data('last_shutdown', default={})

        if not last_shutdown or not isinstance(last_shutdown, dict):
            return False

        if last_shutdown.get('type') == 'graceful':
            print("Last shutdown was graceful - no recovery needed")
            return False

        print("Crash detected - initiating recovery...")
        return True

    def perform_crash_recovery(self) -> bool:
        """Roll back to the snapshot named in the last_shutdown record.

        Every outcome (success, snapshot not found, or exception) is appended
        to ``recovery_log``.

        Returns:
            True if the rollback succeeded, False otherwise.
        """
        try:
            # Load last known good state
            last_shutdown = self.persistence_manager.load_data('last_shutdown', default={})
            snapshot_id = last_shutdown.get('snapshot_id') if isinstance(last_shutdown, dict) else None

            if snapshot_id and self.state_manager.rollback_to_snapshot(snapshot_id):
                self._log_recovery(True, snapshot_id=snapshot_id)
                print(f"Successfully recovered to snapshot {snapshot_id}")
                return True

            # Previously this path returned without a log entry, so reports
            # showed no trace of the failed recovery.
            self._log_recovery(False, snapshot_id=snapshot_id,
                               error='recovery snapshot not found')
            print("Could not find recovery snapshot")
            return False

        except Exception as e:
            self._log_recovery(False, error=str(e))
            print(f"Recovery failed: {str(e)}")
            return False

    def _log_recovery(self, success: bool, snapshot_id: Optional[str] = None,
                      error: Optional[str] = None):
        """Append one crash-recovery entry to ``recovery_log``."""
        entry: Dict[str, Any] = {
            'timestamp': datetime.now().isoformat(),
            'type': 'crash_recovery',
        }
        if snapshot_id is not None:
            entry['snapshot_id'] = snapshot_id
        entry['success'] = success
        if error is not None:
            entry['error'] = error
        self.recovery_log.append(entry)


class RobustStateSystem:
    """Unified state management system with full robustness.

    Wires a StateManager, a PersistenceManager and a CrashRecoveryManager
    together. On construction it checks for (and attempts to recover from) an
    unclean previous exit, then marks the current session as running.
    """

    def __init__(self, state_dir: str = "./eden_states", data_dir: str = "./eden_data"):
        """
        Args:
            state_dir: Snapshot directory, passed to StateManager.
            data_dir: Persisted-data directory, passed to PersistenceManager.
        """
        self.state_manager = StateManager(state_dir)
        self.persistence_manager = PersistenceManager(data_dir)
        self.crash_recovery = CrashRecoveryManager(
            self.state_manager,
            self.persistence_manager
        )

        # Check for crash recovery on initialization
        if self.crash_recovery.check_crash_recovery():
            self.crash_recovery.perform_crash_recovery()

        self._mark_running()

    def _mark_running(self):
        """Record in 'last_shutdown' that a session is in progress.

        graceful_shutdown() overwrites this record with type 'graceful'. If
        the process dies first, the next startup sees type 'running', so
        check_crash_recovery() reports a crash. perform_crash_recovery() then
        attempts a rollback to the snapshot that was newest when this session
        started. Without this marker, nothing ever wrote a non-graceful
        record, so crash detection could never trigger.
        """
        history = self.state_manager.state_history
        self.persistence_manager.save_data('last_shutdown', {
            'timestamp': datetime.now().isoformat(),
            'snapshot_id': history[-1].snapshot_id if history else None,
            'type': 'running'
        })

    def get_system_report(self) -> str:
        """Build a text report: state key count, the 5 newest snapshots and
        the 5 newest recovery-log entries."""
        history = self.state_manager.get_state_history()

        report = f"""
Eden Core - State Management Report
{'='*50}
Generated: {datetime.now().isoformat()}

STATE MANAGEMENT:
{'-'*50}
Current State Keys: {len(self.state_manager.current_state)}
State History: {len(history)} snapshots

Recent Snapshots:
"""
        for snapshot in history[-5:]:
            report += f"  [{snapshot['timestamp']}] {snapshot['snapshot_id']}\n"
            if snapshot['metadata']:
                report += f"    Operation: {snapshot['metadata'].get('operation', 'N/A')}\n"

        report += f"\nRECOVERY LOG:\n{'-'*50}\n"
        for log in self.crash_recovery.recovery_log[-5:]:
            status = "✓" if log.get('success') else "✗"
            report += f"{status} [{log['timestamp']}] {log['type']}\n"

        return report


# Example usage: run this module directly for a small end-to-end demo.
if __name__ == "__main__":
    eden = RobustStateSystem()
    states = eden.state_manager

    # Seed a few state values; each update records its own snapshot.
    for key, value in (
        ('learning_rate', 0.001),
        ('model_version', 'v2.5'),
        ('capabilities', ['reasoning', 'learning', 'planning']),
    ):
        states.update_state(key, value)

    # Persist configuration alongside the versioned state.
    eden.persistence_manager.save_data('config', {
        'max_tokens': 100000,
        'temperature': 0.7,
        'safety_level': 'high'
    })

    # Explicit checkpoint to return to later.
    checkpoint = states.create_snapshot(metadata={
        'operation': 'training_checkpoint',
        'epoch': 42
    })
    print(f"Created snapshot: {checkpoint.snapshot_id}")

    # Drift away from the checkpoint...
    for key, value in (('learning_rate', 0.0005), ('model_version', 'v2.6')):
        states.update_state(key, value)

    # ...then return to it.
    print("\nRolling back to previous snapshot...")
    states.rollback_to_snapshot(checkpoint.snapshot_id)
    print(f"Learning rate after rollback: {states.get_state('learning_rate')}")

    print("\n" + eden.get_system_report())

    eden.crash_recovery.graceful_shutdown()
