"""
Eden Core - Security & Safety System
Addresses: Sandboxing, permission management, action filtering, safety constraints
"""

import hashlib
import json
import os
from collections import deque
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Set, Optional, Callable


class PermissionLevel(Enum):
    """Permission levels for operations, least to most privileged.

    SecurityManager compares these through an explicit numeric hierarchy
    (READ_ONLY < LIMITED_WRITE < FULL_ACCESS < ADMIN); unknown operations
    default to READ_ONLY.
    """
    READ_ONLY = "read_only"
    LIMITED_WRITE = "limited_write"
    FULL_ACCESS = "full_access"
    ADMIN = "admin"


class RiskLevel(Enum):
    """Risk assessment levels, ordered SAFE (lowest) to CRITICAL (highest).

    SafetyController compares these through an explicit numeric ordering
    when aggregating the maximum risk of an action's violations.
    """
    SAFE = "safe"
    LOW = "low"
    MEDIUM = "medium"
    HIGH = "high"
    CRITICAL = "critical"


@dataclass
class SafetyConstraint:
    """Safety constraint definition, registered with SafetyController by name."""
    name: str  # unique key in SafetyController.constraints
    description: str  # human-readable text, surfaced in violation records
    checker: Callable[[Any], bool]  # returns True when the action is acceptable
    risk_level: RiskLevel  # severity attributed to a violation of this constraint
    auto_block: bool = True  # when True, a violation blocks the action outright


@dataclass
class ActionAudit:
    """Audit log entry recording the outcome of one authorization check."""
    timestamp: str  # ISO-8601 string from datetime.now().isoformat()
    action_type: str  # the action's 'type' field ('unknown' when absent)
    parameters: Dict[str, Any]  # the full action dict as submitted
    permission_level: str  # .value of the PermissionLevel that was required
    approved: bool  # True when the action passed all checks
    risk_level: str  # RiskLevel .value assessed for the action
    reason: Optional[str] = None  # explanation, populated for denials


class SecurityManager:
    """Comprehensive security and permission management.

    Tracks per-operation permission levels, a set of permanently blocked
    operations, a filesystem allow-list, and hashes of previously recorded
    actions (for replay detection).
    """

    # Upper bound on retained action hashes; oldest entries are evicted first.
    MAX_ACTION_HASHES = 10000

    # Numeric ordering of permission levels, lowest to highest privilege.
    _LEVEL_HIERARCHY = {
        PermissionLevel.READ_ONLY: 0,
        PermissionLevel.LIMITED_WRITE: 1,
        PermissionLevel.FULL_ACCESS: 2,
        PermissionLevel.ADMIN: 3
    }

    def __init__(self):
        self.permissions: Dict[str, PermissionLevel] = {}
        self.blocked_operations: Set[str] = set()
        self.allowed_paths: Set[str] = {"/home/claude", "/tmp"}
        self.audit_log: List[ActionAudit] = []
        self.action_hashes: Set[str] = set()  # Prevent replay attacks
        self._hash_order: deque = deque()  # insertion order of hashes, for eviction

    def set_permission(self, operation: str, level: PermissionLevel):
        """Set permission level for an operation."""
        self.permissions[operation] = level

    def check_permission(self, operation: str, required_level: PermissionLevel) -> bool:
        """Return True if `operation` meets `required_level`.

        Unknown operations default to READ_ONLY (least privilege).
        """
        current_level = self.permissions.get(operation, PermissionLevel.READ_ONLY)
        return self._LEVEL_HIERARCHY[current_level] >= self._LEVEL_HIERARCHY[required_level]

    def block_operation(self, operation: str):
        """Permanently block an operation."""
        self.blocked_operations.add(operation)

    def is_blocked(self, operation: str) -> bool:
        """Check if operation is blocked."""
        return operation in self.blocked_operations

    def validate_path(self, path: str) -> bool:
        """Validate that `path` resolves to a location inside an allowed directory.

        Uses os.path.commonpath instead of a plain startswith check so that
        sibling directories such as "/tmp_evil" are not accepted as children
        of "/tmp".
        """
        abs_path = os.path.abspath(path)
        for allowed in self.allowed_paths:
            try:
                if os.path.commonpath([abs_path, allowed]) == allowed:
                    return True
            except ValueError:
                # Raised for mixed absolute/relative paths or different drives.
                continue
        return False

    def create_action_hash(self, action: Dict[str, Any]) -> str:
        """Create a deterministic SHA-256 hash of an action (for replay detection)."""
        action_str = json.dumps(action, sort_keys=True)
        return hashlib.sha256(action_str.encode()).hexdigest()

    def is_replay_attack(self, action: Dict[str, Any]) -> bool:
        """Check if an identical action has already been recorded."""
        return self.create_action_hash(action) in self.action_hashes

    def record_action(self, action: Dict[str, Any]):
        """Record an action hash to prevent replay, evicting old entries.

        Enforces the MAX_ACTION_HASHES cap with FIFO eviction so the hash
        set stays bounded (the cap was previously documented but never
        enforced).
        """
        action_hash = self.create_action_hash(action)
        if action_hash in self.action_hashes:
            return  # already recorded; keep the eviction queue duplicate-free
        self.action_hashes.add(action_hash)
        self._hash_order.append(action_hash)
        while len(self._hash_order) > self.MAX_ACTION_HASHES:
            self.action_hashes.discard(self._hash_order.popleft())


class SafetyController:
    """Safety constraint enforcement and risk assessment.

    Holds named SafetyConstraint objects, evaluates actions against all of
    them, tracks per-constraint violation counts, and logs every check.
    """

    # Numeric ordering of risk levels, lowest to highest severity.
    _RISK_ORDER = {
        RiskLevel.SAFE: 0,
        RiskLevel.LOW: 1,
        RiskLevel.MEDIUM: 2,
        RiskLevel.HIGH: 3,
        RiskLevel.CRITICAL: 4
    }

    def __init__(self):
        self.constraints: Dict[str, SafetyConstraint] = {}
        self.safety_log: List[Dict] = []  # one {'action', 'result'} entry per check
        self.violation_count: Dict[str, int] = {}  # constraint name -> total violations

        # Setup default safety constraints
        self._setup_default_constraints()

    def _setup_default_constraints(self):
        """Register the built-in safety constraints."""
        # Constraint: No system modifications
        self.add_constraint(SafetyConstraint(
            name="no_system_modification",
            description="Prevent modifications to system files",
            checker=lambda action: not self._is_system_modification(action),
            risk_level=RiskLevel.CRITICAL,
            auto_block=True
        ))

        # Constraint: No network access without approval
        self.add_constraint(SafetyConstraint(
            name="controlled_network_access",
            description="Network access requires explicit approval",
            checker=lambda action: not self._is_network_operation(action) or action.get('approved', False),
            risk_level=RiskLevel.HIGH,
            auto_block=False
        ))

        # Constraint: Resource limits
        self.add_constraint(SafetyConstraint(
            name="resource_limits",
            description="Enforce resource usage limits",
            checker=lambda action: self._check_resource_limits(action),
            risk_level=RiskLevel.MEDIUM,
            auto_block=True
        ))

        # Constraint: No self-modification without review
        self.add_constraint(SafetyConstraint(
            name="controlled_self_modification",
            description="Self-modification requires human review",
            checker=lambda action: not self._is_self_modification(action) or action.get('human_approved', False),
            risk_level=RiskLevel.CRITICAL,
            auto_block=True
        ))

    def _is_system_modification(self, action: Dict) -> bool:
        """Return True when the action's path targets a protected system directory."""
        dangerous_paths = ['/etc/', '/sys/', '/proc/', '/boot/']
        path = action.get('path', '')
        return any(path.startswith(dp) for dp in dangerous_paths)

    def _is_network_operation(self, action: Dict) -> bool:
        """Return True when the action's type involves network access."""
        network_operations = ['http_request', 'socket_connect', 'download', 'upload']
        return action.get('type') in network_operations

    def _check_resource_limits(self, action: Dict) -> bool:
        """Return True when the action's declared resources are within limits.

        Actions that declare no resources pass (each resource defaults to 0).
        """
        limits = {
            'max_memory_mb': 1000,
            'max_cpu_time_sec': 60,
            'max_file_size_mb': 100
        }
        resources = action.get('resources', {})
        return all(resources.get(resource, 0) <= limit
                   for resource, limit in limits.items())

    def _is_self_modification(self, action: Dict) -> bool:
        """Return True when the action's path touches Eden's own code."""
        self_paths = ['eden_core.py', 'capabilities/', 'real_capabilities/']
        path = action.get('path', '')
        return any(sp in path for sp in self_paths)

    def add_constraint(self, constraint: SafetyConstraint):
        """Add (or replace) a safety constraint, keyed by its name."""
        self.constraints[constraint.name] = constraint

    def check_safety(self, action: Dict[str, Any]) -> Dict[str, Any]:
        """
        Check action against all safety constraints.

        A checker that raises is treated as a HIGH-risk violation and blocks
        the action (fail closed). Such failures now also raise the reported
        overall risk level, which previously could stay "safe" for a blocked
        action.

        Returns: Dict with keys: safe, violations, risk_level, blocked
        """
        violations = []
        max_risk = RiskLevel.SAFE
        blocked = False

        for constraint_name, constraint in self.constraints.items():
            try:
                if not constraint.checker(action):
                    violations.append({
                        'constraint': constraint_name,
                        'description': constraint.description,
                        'risk_level': constraint.risk_level.value
                    })

                    # Track violations
                    self.violation_count[constraint_name] = self.violation_count.get(constraint_name, 0) + 1

                    # Update max risk level
                    if self._compare_risk(constraint.risk_level, max_risk) > 0:
                        max_risk = constraint.risk_level

                    # Block if constraint requires auto-blocking
                    if constraint.auto_block:
                        blocked = True
            except Exception as e:
                violations.append({
                    'constraint': constraint_name,
                    'description': f"Constraint check failed: {str(e)}",
                    'risk_level': RiskLevel.HIGH.value
                })
                # A failed check counts toward the overall risk level too.
                if self._compare_risk(RiskLevel.HIGH, max_risk) > 0:
                    max_risk = RiskLevel.HIGH
                blocked = True

        result = {
            'safe': len(violations) == 0,
            'violations': violations,
            'risk_level': max_risk.value,
            'blocked': blocked,
            'timestamp': datetime.now().isoformat()
        }

        self.safety_log.append({
            'action': action,
            'result': result
        })

        return result

    def _compare_risk(self, risk1: RiskLevel, risk2: RiskLevel) -> int:
        """Compare two risk levels; returns -1, 0, or 1 (cmp-style)."""
        val1 = self._RISK_ORDER[risk1]
        val2 = self._RISK_ORDER[risk2]
        return (val1 > val2) - (val1 < val2)

    def get_safety_report(self) -> str:
        """Generate a human-readable safety report from the accumulated log."""
        total_checks = len(self.safety_log)
        safe_actions = sum(1 for log in self.safety_log if log['result']['safe'])

        report = f"""
Eden Core - Safety Report
{'='*50}
Generated: {datetime.now().isoformat()}

SAFETY SUMMARY:
{'-'*50}
Total Actions Checked: {total_checks}
Safe Actions: {safe_actions} ({100*safe_actions/max(total_checks, 1):.1f}%)
Violations: {total_checks - safe_actions}

CONSTRAINT VIOLATIONS:
{'-'*50}
"""

        # Most frequently violated constraints first.
        for constraint_name, count in sorted(self.violation_count.items(), key=lambda x: x[1], reverse=True):
            report += f"{constraint_name}: {count} violations\n"

        report += "\nMost recent safety checks:\n"
        for log in self.safety_log[-5:]:
            result = log['result']
            report += f"  [{result['timestamp']}] Risk: {result['risk_level']}, Safe: {result['safe']}\n"

        return report


class RobustSecuritySystem:
    """Unified security and safety system.

    Routes every action through the SecurityManager (blocking, replay,
    permission, and path checks) and the SafetyController (constraint
    checks), producing an audit trail for both approved and denied actions.
    """

    def __init__(self):
        self.security_manager = SecurityManager()
        self.safety_controller = SafetyController()
        self.audit_log: List[ActionAudit] = []

        # Setup default permissions
        self._setup_default_permissions()

    def _setup_default_permissions(self):
        """Set conservative default permission levels per operation type."""
        self.security_manager.set_permission("file_read", PermissionLevel.READ_ONLY)
        self.security_manager.set_permission("file_write", PermissionLevel.LIMITED_WRITE)
        self.security_manager.set_permission("bash_command", PermissionLevel.LIMITED_WRITE)
        self.security_manager.set_permission("network_request", PermissionLevel.FULL_ACCESS)
        self.security_manager.set_permission("system_modification", PermissionLevel.ADMIN)
        self.security_manager.set_permission("self_modification", PermissionLevel.ADMIN)

    def authorize_action(self, action: Dict[str, Any],
                        required_permission: PermissionLevel = PermissionLevel.LIMITED_WRITE) -> Dict[str, Any]:
        """
        Authorize action through security and safety checks.

        Denied actions are now recorded in the audit log as well, so
        get_security_report() reflects real denial counts (previously only
        approvals were audited and "Denied" was always 0).

        Returns: Dict with keys: authorized, reason, risk_level, audit_id
        """
        action_type = action.get('type', 'unknown')

        # Check if operation is blocked
        if self.security_manager.is_blocked(action_type):
            return self._deny(action, required_permission, "Operation is permanently blocked")

        # Check for replay attacks
        if self.security_manager.is_replay_attack(action):
            return self._deny(action, required_permission, "Potential replay attack detected")

        # Check permissions
        if not self.security_manager.check_permission(action_type, required_permission):
            return self._deny(action, required_permission,
                              f"Insufficient permissions (requires {required_permission.value})")

        # Validate paths if applicable
        if 'path' in action and not self.security_manager.validate_path(action['path']):
            return self._deny(action, required_permission, "Path not in allowed directories")

        # Safety checks
        safety_result = self.safety_controller.check_safety(action)

        if safety_result['blocked']:
            violated = ', '.join(v['constraint'] for v in safety_result['violations'])
            return self._deny(action, required_permission,
                              f"Safety violation: {violated}",
                              risk_level=safety_result['risk_level'])

        # Record approved action (for replay protection)
        self.security_manager.record_action(action)

        # Create audit log entry
        audit = ActionAudit(
            timestamp=datetime.now().isoformat(),
            action_type=action_type,
            parameters=action,
            permission_level=required_permission.value,
            approved=True,
            risk_level=safety_result['risk_level']
        )
        self.audit_log.append(audit)

        return {
            'authorized': True,
            'reason': 'Action approved',
            'risk_level': safety_result['risk_level'],
            'audit_id': len(self.audit_log) - 1,
            # Non-blocking violations surface as warnings.
            'warnings': safety_result['violations'] if not safety_result['safe'] else []
        }

    def _deny(self, action: Dict[str, Any], required_permission: PermissionLevel,
              reason: str, risk_level: str = "medium") -> Dict[str, Any]:
        """Record a denial in the audit log and build the denial response."""
        audit = ActionAudit(
            timestamp=datetime.now().isoformat(),
            action_type=action.get('type', 'unknown'),
            parameters=action,
            permission_level=required_permission.value,
            approved=False,
            risk_level=risk_level,
            reason=reason
        )
        self.audit_log.append(audit)

        response = self._create_denial(reason, risk_level)
        response['audit_id'] = len(self.audit_log) - 1
        return response

    def _create_denial(self, reason: str, risk_level: str = "medium") -> Dict[str, Any]:
        """Create denial response dict (authorized=False)."""
        return {
            'authorized': False,
            'reason': reason,
            'risk_level': risk_level,
            'audit_id': None
        }

    def get_security_report(self) -> str:
        """Generate comprehensive security report (authorizations + safety)."""
        approved = sum(1 for audit in self.audit_log if audit.approved)
        denied = len(self.audit_log) - approved

        report = f"""
Eden Core - Security Report
{'='*50}
Generated: {datetime.now().isoformat()}

AUTHORIZATION SUMMARY:
{'-'*50}
Total Actions: {len(self.audit_log)}
Approved: {approved} ({100*approved/max(len(self.audit_log), 1):.1f}%)
Denied: {denied} ({100*denied/max(len(self.audit_log), 1):.1f}%)

"""

        report += self.safety_controller.get_safety_report()

        report += f"\n\nRECENT AUDIT LOG (last 10):\n{'-'*50}\n"
        for audit in self.audit_log[-10:]:
            status = "✓" if audit.approved else "✗"
            report += f"{status} [{audit.timestamp}] {audit.action_type} - Risk: {audit.risk_level}\n"
            if audit.reason:
                report += f"   Reason: {audit.reason}\n"

        return report


# Example usage
if __name__ == "__main__":
    security = RobustSecuritySystem()
    
    # Example actions
    actions = [
        {'type': 'file_read', 'path': '/home/claude/test.txt'},
        {'type': 'file_write', 'path': '/home/claude/output.txt'},
        {'type': 'bash_command', 'command': 'ls -la'},
        {'type': 'file_write', 'path': '/etc/passwd'},  # Should be blocked
        {'type': 'self_modification', 'path': 'eden_core.py'},  # Should be blocked
    ]
    
    for action in actions:
        result = security.authorize_action(action)
        print(f"\nAction: {action['type']}")
        print(f"Authorized: {result['authorized']}")
        print(f"Reason: {result['reason']}")
        print(f"Risk: {result['risk_level']}")
    
    print("\n" + security.get_security_report())
