"""
Eden Core - Robust Error Handling & Recovery System
Addresses: Exception handling, graceful degradation, state recovery
"""

import copy
import functools
import json
import logging
import time
import traceback
from datetime import datetime
from enum import Enum
from typing import Any, Callable, Dict, List, Optional


class ErrorSeverity(Enum):
    """How serious an error is; used to decide how it should be handled."""

    # The system must stop.
    CRITICAL = "critical"
    # Major functionality is impaired.
    HIGH = "high"
    # Part of the functionality is lost.
    MEDIUM = "medium"
    # Minor issue; execution can continue.
    LOW = "low"


class RecoveryStrategy(Enum):
    """The recovery action to take after an error has been recorded."""

    # Run the failed operation again.
    RETRY = "retry"
    # Invoke a registered fallback handler instead.
    FALLBACK = "fallback"
    # Continue in a reduced-functionality mode.
    DEGRADE = "degrade"
    # Give up on the operation.
    ABORT = "abort"
    # Roll back to a previously saved checkpoint.
    CHECKPOINT_RESTORE = "checkpoint_restore"


class RobustErrorHandler:
    """Comprehensive error handling system with recovery strategies.

    Records every handled error in ``error_history``, keeps a bounded stack of
    state checkpoints, and dispatches each error to the recovery routine that
    matches the requested ``RecoveryStrategy``.
    """

    def __init__(self, log_file: str = "eden_errors.log", max_checkpoints: int = 10):
        """
        Args:
            log_file: File passed to ``logging.basicConfig``. ``basicConfig``
                configures the root logger and does nothing if the root logger
                already has handlers, so only the first configuration in the
                process takes effect.
            max_checkpoints: Maximum number of checkpoints retained; the oldest
                is discarded first. Must be at least 1.

        Raises:
            ValueError: if ``max_checkpoints`` is less than 1.
        """
        if max_checkpoints < 1:
            raise ValueError("max_checkpoints must be at least 1")

        self.log_file = log_file
        self.max_checkpoints = max_checkpoints
        self.error_history: List[Dict] = []
        self.checkpoint_stack: List[Dict] = []
        self.fallback_handlers: Dict[str, Callable] = {}
        self.retry_policies: Dict[str, Dict] = {}

        # Setup logging
        logging.basicConfig(
            filename=log_file,
            level=logging.INFO,
            format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
        )
        self.logger = logging.getLogger("EdenRobustHandler")

    def register_fallback(self, operation_name: str, handler: Callable):
        """Register a fallback handler for specific operations.

        The handler is called with the error context dict and its return value
        becomes the recovery ``result``.
        """
        self.fallback_handlers[operation_name] = handler
        self.logger.info(f"Registered fallback for: {operation_name}")

    def set_retry_policy(self, operation_name: str, max_retries: int = 3,
                         backoff_factor: float = 2.0):
        """Set retry policy for specific operations.

        After failed attempt ``k`` (0-based) the retry loop sleeps
        ``backoff_factor ** k`` seconds before trying again.
        """
        self.retry_policies[operation_name] = {
            'max_retries': max_retries,
            'backoff_factor': backoff_factor
        }

    def create_checkpoint(self, state: Dict[str, Any], label: str = "auto"):
        """Create a checkpoint of current system state.

        The state is deep-copied, so later mutation of the caller's objects
        does not alter the saved snapshot.
        """
        checkpoint = {
            'timestamp': datetime.now().isoformat(),
            'label': label,
            'state': copy.deepcopy(state)
        }
        self.checkpoint_stack.append(checkpoint)
        self.logger.info(f"Checkpoint created: {label}")

        # Bound the stack to prevent memory growth; drop the oldest first.
        while len(self.checkpoint_stack) > self.max_checkpoints:
            self.checkpoint_stack.pop(0)

    def restore_checkpoint(self, label: Optional[str] = None) -> Optional[Dict]:
        """Restore from most recent or specific checkpoint.

        Args:
            label: Label to look for (the newest match wins). ``None`` selects
                the most recent checkpoint.

        Returns:
            A deep copy of the saved state, so the caller cannot corrupt the
            checkpoint, or ``None`` if no matching checkpoint exists.
        """
        if not self.checkpoint_stack:
            self.logger.warning("No checkpoints available for restoration")
            return None

        if label is None:
            checkpoint = self.checkpoint_stack[-1]
            self.logger.info("Restored most recent checkpoint")
            return copy.deepcopy(checkpoint['state'])

        for checkpoint in reversed(self.checkpoint_stack):
            if checkpoint['label'] == label:
                self.logger.info(f"Restored checkpoint: {label}")
                return copy.deepcopy(checkpoint['state'])
        self.logger.warning(f"Checkpoint {label} not found")
        return None

    def handle_error(self, error: Exception, context: Dict[str, Any],
                     severity: ErrorSeverity = ErrorSeverity.MEDIUM,
                     strategy: RecoveryStrategy = RecoveryStrategy.RETRY) -> Dict[str, Any]:
        """
        Comprehensive error handling with recovery strategies

        Args:
            error: The exception being handled.
            context: Free-form details about the failure. ``'operation_name'``
                selects retry policies and fallbacks. ``'operation'``, if
                present, is a zero-argument callable used by the RETRY
                strategy to re-run the failed operation.
            severity: Severity recorded with the error.
            strategy: Recovery strategy to execute.

        Returns: Dict with keys: success, result, recovery_action
        """
        # Format the traceback of *this* exception. traceback.format_exc()
        # describes whatever exception is currently being handled, which is
        # unrelated (or 'NoneType: None') when called outside that except block.
        tb_text = ''.join(
            traceback.format_exception(type(error), error, error.__traceback__)
        )
        # The retry callable is not recorded, so error_history stays
        # JSON-serializable.
        recorded_context = {k: v for k, v in context.items() if k != 'operation'}

        error_record = {
            'timestamp': datetime.now().isoformat(),
            'error_type': type(error).__name__,
            'error_message': str(error),
            'context': recorded_context,
            'severity': severity.value,
            'strategy': strategy.value,
            'traceback': tb_text
        }

        self.error_history.append(error_record)
        self.logger.error(f"Error occurred: {error_record}")

        # Execute recovery strategy
        if strategy == RecoveryStrategy.RETRY:
            return self._retry_operation(context)
        elif strategy == RecoveryStrategy.FALLBACK:
            return self._fallback_operation(context)
        elif strategy == RecoveryStrategy.DEGRADE:
            return self._degrade_gracefully(context)
        elif strategy == RecoveryStrategy.CHECKPOINT_RESTORE:
            return self._restore_from_checkpoint(context)
        elif strategy == RecoveryStrategy.ABORT:
            return {'success': False, 'error': str(error), 'recovery_action': 'aborted'}

        return {'success': False, 'error': str(error), 'recovery_action': 'none'}

    def _retry_operation(self, context: Dict[str, Any]) -> Dict[str, Any]:
        """Retry ``context['operation']`` with exponential backoff.

        Without a callable ``'operation'`` in the context there is nothing to
        re-run, so failure is reported instead of a fabricated success.
        """
        operation_name = context.get('operation_name', 'unknown')
        operation = context.get('operation')
        if not callable(operation):
            self.logger.warning(f"No retryable operation supplied for {operation_name}")
            return {'success': False, 'recovery_action': 'no_operation_to_retry'}

        policy = self.retry_policies.get(operation_name, {'max_retries': 3, 'backoff_factor': 2.0})
        max_retries = policy['max_retries']

        for attempt in range(max_retries):
            self.logger.info(f"Retry attempt {attempt + 1} for {operation_name}")
            try:
                result = operation()
            except Exception as e:
                self.logger.warning(f"Retry attempt {attempt + 1} for {operation_name} failed: {e}")
                if attempt == max_retries - 1:
                    return {'success': False, 'error': str(e),
                            'recovery_action': 'retry_failed', 'attempts': attempt + 1}
                time.sleep(policy['backoff_factor'] ** attempt)
            else:
                return {'success': True, 'result': result,
                        'recovery_action': 'retry_successful', 'attempts': attempt + 1}

        # Only reached when max_retries <= 0.
        return {'success': False, 'recovery_action': 'retry_exhausted'}

    def _fallback_operation(self, context: Dict[str, Any]) -> Dict[str, Any]:
        """Execute fallback handler if available"""
        operation_name = context.get('operation_name', 'unknown')

        if operation_name in self.fallback_handlers:
            try:
                result = self.fallback_handlers[operation_name](context)
                self.logger.info(f"Fallback successful for {operation_name}")
                return {'success': True, 'result': result, 'recovery_action': 'fallback'}
            except Exception as e:
                self.logger.error(f"Fallback failed for {operation_name}: {str(e)}")
                return {'success': False, 'error': str(e), 'recovery_action': 'fallback_failed'}

        return {'success': False, 'recovery_action': 'no_fallback_available'}

    def _degrade_gracefully(self, context: Dict[str, Any]) -> Dict[str, Any]:
        """Degrade to simpler functionality (reports success with mode 'degraded')."""
        self.logger.info("Degrading to reduced functionality mode")
        return {
            'success': True,
            'mode': 'degraded',
            'recovery_action': 'graceful_degradation',
            'message': 'Operating with reduced functionality'
        }

    def _restore_from_checkpoint(self, context: Dict[str, Any]) -> Dict[str, Any]:
        """Restore from most recent checkpoint"""
        state = self.restore_checkpoint()
        # An empty dict is a valid saved state; only None means "no checkpoint".
        if state is not None:
            return {'success': True, 'state': state, 'recovery_action': 'checkpoint_restored'}
        return {'success': False, 'recovery_action': 'no_checkpoint_available'}

    def get_error_statistics(self) -> Dict[str, Any]:
        """Get statistics about errors and recoveries.

        Returns total count, per-severity / per-type / per-strategy counts, and
        the five most recent error records.
        """
        if not self.error_history:
            return {'total_errors': 0}

        stats = {
            'total_errors': len(self.error_history),
            'by_severity': {},
            'by_type': {},
            'by_strategy': {},
            'recent_errors': self.error_history[-5:]
        }

        for error in self.error_history:
            for stat_key, record_key in (('by_severity', 'severity'),
                                         ('by_type', 'error_type'),
                                         ('by_strategy', 'strategy')):
                bucket = stats[stat_key]
                value = error[record_key]
                bucket[value] = bucket.get(value, 0) + 1

        return stats


def robust_operation(severity: ErrorSeverity = ErrorSeverity.MEDIUM,
                     strategy: RecoveryStrategy = RecoveryStrategy.RETRY,
                     operation_name: Optional[str] = None,
                     handler: Optional[RobustErrorHandler] = None):
    """Decorator for making operations robust with automatic error handling.

    Args:
        severity: Severity recorded for errors raised by the wrapped function.
        strategy: Recovery strategy passed to ``handle_error``.
        operation_name: Name used to look up retry policies and fallbacks;
            defaults to the wrapped function's ``__name__``.
        handler: Handler to use. Pass one to pre-register fallbacks and retry
            policies and to inspect its error history. If omitted, a fresh
            ``RobustErrorHandler`` is created for each failure, so no
            fallbacks or retry policies are registered on it.

    The wrapper returns the function's result. If the function raises and
    recovery succeeds, it returns the recovery result's ``'result'`` value,
    which may be ``None``.

    Raises:
        Exception: if recovery fails; the original exception is chained as
            ``__cause__``.
    """
    def decorator(func: Callable) -> Callable:
        name = operation_name or func.__name__

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except Exception as e:
                # Only build a default handler when actually needed; it calls
                # logging.basicConfig, so avoid doing that on every success.
                active_handler = handler if handler is not None else RobustErrorHandler()
                context = {
                    'operation_name': name,
                    'args': str(args),
                    'kwargs': str(kwargs),
                    # Zero-argument re-invocation used by the RETRY strategy.
                    'operation': functools.partial(func, *args, **kwargs),
                }
                result = active_handler.handle_error(e, context, severity, strategy)

                if not result['success']:
                    raise Exception(f"Operation failed after recovery attempt: {result}") from e

                return result.get('result')

        return wrapper
    return decorator


# Example usage
if __name__ == "__main__":
    demo_handler = RobustErrorHandler()

    # Fallback that answers from cached knowledge instead of re-learning.
    def cached_knowledge_fallback(ctx):
        return "Using cached knowledge instead of learning"

    demo_handler.register_fallback("meta_learning", cached_knowledge_fallback)
    demo_handler.set_retry_policy("meta_learning", max_retries=5, backoff_factor=1.5)

    # Snapshot the state before the risky operation runs.
    snapshot = {'capabilities': ['reasoning', 'learning'], 'state': 'active'}
    demo_handler.create_checkpoint(snapshot, label="pre_operation")

    # Trigger a failure and recover through the registered fallback.
    try:
        raise ValueError("Simulated learning error")
    except Exception as exc:
        recovery_context = {'operation_name': 'meta_learning', 'data': 'training_batch_1'}
        outcome = demo_handler.handle_error(
            exc,
            recovery_context,
            severity=ErrorSeverity.HIGH,
            strategy=RecoveryStrategy.FALLBACK,
        )
        print(f"Recovery result: {outcome}")

    # Summarize what was recorded.
    summary = demo_handler.get_error_statistics()
    print(f"\nError Statistics: {json.dumps(summary, indent=2)}")
