"""
Eden Core - Validation & Monitoring System
Addresses: Input validation, output verification, system health monitoring
"""

import time
import psutil
import numpy as np
from typing import Any, Dict, List, Optional, Callable
from dataclasses import dataclass
from datetime import datetime
import json


@dataclass
class HealthMetrics:
    """Snapshot of system health taken at a single point in time."""
    cpu_percent: float            # CPU utilisation, in percent
    memory_percent: float         # memory utilisation, in percent
    timestamp: str                # ISO-8601 time the snapshot was taken
    operation_count: int          # operations recorded so far
    error_rate: float             # failed operations / recorded operations
    average_response_time: float  # mean operation duration, in seconds

    def is_healthy(self, thresholds: Optional[Dict[str, float]] = None) -> bool:
        """Check if system is within healthy parameters.

        Every metric must be strictly below its limit.

        Args:
            thresholds: Mapping that may contain ``cpu``, ``memory``,
                ``error_rate`` and ``response_time``. Missing keys (or
                ``None``) fall back to 80.0, 85.0, 0.1 and 5.0 respectively.

        Returns:
            A plain Python ``bool``. The fields may hold numpy scalars (for
            example from ``np.mean``), and comparing those yields ``np.bool_``,
            which ``json`` cannot serialise and which fails ``is True`` checks.
        """
        limits = thresholds or {}
        return bool(
            self.cpu_percent < limits.get('cpu', 80.0)
            and self.memory_percent < limits.get('memory', 85.0)
            and self.error_rate < limits.get('error_rate', 0.1)
            and self.average_response_time < limits.get('response_time', 5.0)
        )


@dataclass
class ValidationRule:
    """One named check applied to inputs by ``InputValidator``.

    The ``validator`` callable receives the input and should return a truthy
    value when the input passes. When it returns a falsy value, ``severity``
    decides how the failure is reported: ``"critical"`` makes it an error,
    and any other value makes it a warning.
    """
    name: str                          # label prefixed to reported messages
    validator: Callable[[Any], bool]   # predicate; falsy result = failed check
    error_message: str                 # text reported when the check fails
    severity: str = "medium"           # "critical" -> error, else -> warning


class InputValidator:
    """Comprehensive input validation system.

    Rules are registered per input type. When a rule's validator returns a
    falsy value, the failure becomes an error if the rule's severity is
    ``"critical"`` and a warning otherwise. A validator that raises is always
    reported as an error.
    """

    def __init__(self, max_history: int = 1000):
        """
        Args:
            max_history: Maximum number of entries kept in
                ``validation_history``. The oldest entries are dropped first,
                so memory use stays bounded in long-running processes.
        """
        self.rules: Dict[str, List['ValidationRule']] = {}
        self.validation_history: List[Dict] = []
        self.max_history = max_history

    def register_rule(self, input_type: str, rule: 'ValidationRule'):
        """Register a validation rule for specific input type"""
        self.rules.setdefault(input_type, []).append(rule)

    def validate(self, input_type: str, data: Any) -> Dict[str, Any]:
        """
        Validate input data against registered rules

        Returns: Dict with keys: valid, errors, warnings, timestamp
        """
        if input_type not in self.rules:
            return {
                'valid': True,
                'errors': [],
                'warnings': ['No validation rules defined'],
                'timestamp': datetime.now().isoformat()
            }

        errors = []
        warnings = []

        for rule in self.rules[input_type]:
            try:
                passed = rule.validator(data)
            except Exception as e:
                # Any crash inside a rule makes the input invalid, regardless
                # of the rule's severity.
                errors.append(f"{rule.name}: Validation failed with error - {str(e)}")
                continue
            if not passed:
                target = errors if rule.severity == "critical" else warnings
                target.append(f"{rule.name}: {rule.error_message}")

        result = {
            'valid': len(errors) == 0,
            'errors': errors,
            'warnings': warnings,
            'timestamp': datetime.now().isoformat()
        }

        self.validation_history.append({
            'input_type': input_type,
            'result': result
        })
        # Bound the history. Computing the excess explicitly keeps
        # max_history == 0 working (a `[:-0]` slice would delete nothing).
        excess = len(self.validation_history) - self.max_history
        if excess > 0:
            del self.validation_history[:excess]

        return result


class OutputVerifier:
    """Output verification and quality assurance.

    Verifiers are callables registered per output type. Each one receives the
    output and may return:
      * a dict with an optional ``score`` (default 1.0) and optional
        ``issues`` (an iterable of messages, or a single message string);
      * a bool (True -> 1.0, False -> 0.0);
      * anything ``float()`` accepts, which is used directly as the score.
    A verifier that raises, or returns something that cannot be interpreted
    this way, contributes a score of 0.0 plus an issue message.
    """

    def __init__(self):
        self.verification_rules: Dict[str, List[Callable]] = {}
        self.quality_thresholds: Dict[str, float] = {}

    def register_verifier(self, output_type: str, verifier: Callable):
        """Register an output verification function"""
        self.verification_rules.setdefault(output_type, []).append(verifier)

    def set_quality_threshold(self, metric: str, threshold: float):
        """Set quality threshold for output metrics.

        ``verify`` consults only the ``'overall'`` entry (default 0.7).
        """
        self.quality_thresholds[metric] = threshold

    @staticmethod
    def _interpret(result: Any):
        """Normalise one verifier's return value into ``(score, issues)``.

        Raises TypeError/ValueError when the score is not numeric or the
        ``issues`` entry is not iterable. ``verify`` records that as a
        verification error. Nothing is recorded before the whole result has
        been interpreted, so a bad result never counts twice.
        """
        if isinstance(result, dict):
            score = float(result.get('score', 1.0))
            found = result.get('issues') or []
            # extend() on a bare string would add it one character at a time.
            found = [found] if isinstance(found, str) else list(found)
            return score, found
        if isinstance(result, bool):
            return (1.0 if result else 0.0), []
        return float(result), []

    def verify(self, output_type: str, data: Any) -> Dict[str, Any]:
        """
        Verify output meets quality standards

        Returns: Dict with keys: verified, quality_score, issues, timestamp.
        ``verified`` is a plain ``bool`` and ``quality_score`` a plain
        ``float`` (numpy scalars are converted so the result stays
        JSON-serialisable).
        """
        if output_type not in self.verification_rules:
            return {
                'verified': True,
                'quality_score': 1.0,
                'issues': [],
                'timestamp': datetime.now().isoformat()
            }

        issues = []
        quality_scores = []

        for verifier in self.verification_rules[output_type]:
            try:
                score, found = self._interpret(verifier(data))
            except Exception as e:
                # Verifiers are arbitrary user callables; any failure counts
                # as a zero score instead of aborting the whole verification.
                issues.append(f"Verification error: {str(e)}")
                quality_scores.append(0.0)
            else:
                quality_scores.append(score)
                issues.extend(found)

        avg_quality = float(np.mean(quality_scores)) if quality_scores else 0.0
        threshold = self.quality_thresholds.get('overall', 0.7)

        return {
            'verified': bool(avg_quality >= threshold) and not issues,
            'quality_score': avg_quality,
            'issues': issues,
            'timestamp': datetime.now().isoformat()
        }


class SystemMonitor:
    """Comprehensive system monitoring and health checks.

    Records per-operation durations and failures, samples CPU/memory usage
    through psutil, keeps a bounded history of HealthMetrics snapshots, and
    stores an alert whenever a snapshot falls outside ``health_thresholds``.
    """

    def __init__(self, check_interval: int = 60, max_history: int = 1000):
        """
        Args:
            check_interval: Stored on the instance; not read by this class.
            max_history: Upper bound on the number of retained operation
                times, metrics snapshots and alerts.
        """
        self.check_interval = check_interval
        self.max_history = max_history
        self.metrics_history: List['HealthMetrics'] = []
        self.operation_times: List[float] = []
        self.error_count = 0
        self.operation_count = 0
        self.alerts: List[Dict] = []
        # Total alerts ever raised; `alerts` itself only keeps the newest ones.
        self.alert_count = 0
        self.health_thresholds = {
            'cpu': 80.0,
            'memory': 85.0,
            'error_rate': 0.1,
            'response_time': 5.0
        }

    def _trim(self, items: List) -> None:
        """Drop the oldest entries of ``items`` so at most max_history remain."""
        excess = len(items) - self.max_history
        if excess > 0:
            del items[:excess]

    def record_operation(self, duration: float, success: bool = True):
        """Record an operation for monitoring"""
        self.operation_times.append(duration)
        self.operation_count += 1
        if not success:
            self.error_count += 1
        self._trim(self.operation_times)

    def collect_metrics(self) -> 'HealthMetrics':
        """Collect current system health metrics.

        Blocks for about one second while psutil samples CPU usage. Appends
        the snapshot to ``metrics_history`` and raises an alert when it is
        unhealthy.
        """
        cpu_percent = psutil.cpu_percent(interval=1)
        memory = psutil.virtual_memory()

        recent_times = self.operation_times[-100:]
        # float() so numpy scalars don't leak into HealthMetrics.
        avg_response_time = float(np.mean(recent_times)) if recent_times else 0.0
        error_rate = self.error_count / max(self.operation_count, 1)

        metrics = HealthMetrics(
            cpu_percent=cpu_percent,
            memory_percent=memory.percent,
            timestamp=datetime.now().isoformat(),
            operation_count=self.operation_count,
            error_rate=error_rate,
            average_response_time=avg_response_time
        )

        self.metrics_history.append(metrics)
        self._trim(self.metrics_history)

        if not metrics.is_healthy(self.health_thresholds):
            self._create_alert(metrics)

        return metrics

    def _create_alert(self, metrics: 'HealthMetrics'):
        """Create alert for unhealthy conditions"""
        alert = {
            'timestamp': metrics.timestamp,
            'type': 'health_warning',
            'details': {
                'cpu': metrics.cpu_percent,
                'memory': metrics.memory_percent,
                'error_rate': metrics.error_rate,
                'response_time': metrics.average_response_time
            }
        }
        self.alerts.append(alert)
        self.alert_count += 1
        self._trim(self.alerts)

    def get_health_status(self) -> Dict[str, Any]:
        """Get current health status summary.

        Uses the latest snapshot, collecting one first if none exists.
        ``alerts`` is the total number of alerts raised so far.
        """
        metrics = self.metrics_history[-1] if self.metrics_history else self.collect_metrics()
        recent_metrics = self.metrics_history[-10:]

        return {
            'current': {
                'cpu': metrics.cpu_percent,
                'memory': metrics.memory_percent,
                'error_rate': metrics.error_rate,
                'response_time': metrics.average_response_time
            },
            'healthy': bool(metrics.is_healthy(self.health_thresholds)),
            'trends': {
                'cpu_trend': self._calculate_trend([m.cpu_percent for m in recent_metrics]),
                'memory_trend': self._calculate_trend([m.memory_percent for m in recent_metrics]),
                'error_trend': self._calculate_trend([m.error_rate for m in recent_metrics])
            },
            'alerts': self.alert_count,
            'operations': self.operation_count
        }

    def _calculate_trend(self, values: List[float]) -> str:
        """Calculate trend direction from values.

        Compares the mean of the last three values with the mean of the
        earlier ones (or the first value when there are at most three). A
        change smaller than 5% of that baseline counts as "stable".
        """
        if len(values) < 2:
            return "stable"

        recent_avg = np.mean(values[-3:])
        older_avg = np.mean(values[:-3]) if len(values) > 3 else values[0]

        diff = recent_avg - older_avg

        # `diff == 0` keeps a constant series (notably an all-zero error rate,
        # where 5% of the baseline is 0) from being reported as "decreasing".
        if diff == 0 or abs(diff) < 0.05 * abs(older_avg):
            return "stable"
        elif diff > 0:
            return "increasing"
        else:
            return "decreasing"

    def get_diagnostics(self) -> Dict[str, Any]:
        """Get detailed diagnostic information"""
        memory = psutil.virtual_memory()
        return {
            'system_info': {
                'cpu_count': psutil.cpu_count(),
                'total_memory_gb': memory.total / (1024**3),
                'available_memory_gb': memory.available / (1024**3)
            },
            'performance': {
                'total_operations': self.operation_count,
                'total_errors': self.error_count,
                'error_rate': self.error_count / max(self.operation_count, 1),
                'avg_response_time': float(np.mean(self.operation_times)) if self.operation_times else 0.0
            },
            'alerts': self.alerts[-10:],  # Last 10 alerts
            'health_status': self.get_health_status()
        }


class RobustSystemManager:
    """Unified system management with validation, verification, and monitoring"""

    def __init__(self):
        self.input_validator = InputValidator()
        self.output_verifier = OutputVerifier()
        self.system_monitor = SystemMonitor()

        # Setup default validation rules
        self._setup_default_rules()

    @staticmethod
    def _query_length_ok(query: Any) -> bool:
        """Return True when ``str(query)`` has 1 to 10000 characters, inclusive.

        The bounds match the rule's error message. The previous `< 10000`
        check wrongly rejected a query of exactly 10000 characters.
        """
        return 1 <= len(str(query)) <= 10000

    @staticmethod
    def _confidence_in_range(score: Any) -> bool:
        """Return True when ``score`` lies in the closed interval [0, 1]."""
        return 0.0 <= score <= 1.0

    @staticmethod
    def _timing(duration: float) -> Dict[str, Any]:
        """Build the ``metrics`` entry of a process_with_validation result."""
        return {
            'duration': duration,
            'timestamp': datetime.now().isoformat()
        }

    def _setup_default_rules(self):
        """Setup default validation and verification rules"""
        # Input validation rules
        self.input_validator.register_rule(
            "query",
            ValidationRule(
                name="query_length",
                validator=self._query_length_ok,
                error_message="Query must be between 1 and 10000 characters",
                severity="critical"
            )
        )

        self.input_validator.register_rule(
            "confidence_score",
            ValidationRule(
                name="score_range",
                validator=self._confidence_in_range,
                error_message="Confidence score must be between 0 and 1",
                severity="critical"
            )
        )

        # Output verification rules
        self.output_verifier.set_quality_threshold('overall', 0.7)
        self.output_verifier.set_quality_threshold('coherence', 0.8)

    def process_with_validation(self, input_type: str, data: Any,
                               operation: Callable, output_type: str) -> Dict[str, Any]:
        """
        Process data with full validation, monitoring, and verification

        Returns: Dict with keys: success, validation, metrics, plus
        result and verification on completion, or error on failure.
        Durations are measured with time.perf_counter(), which is monotonic,
        in seconds.
        """
        start_time = time.perf_counter()

        # Step 1: Validate input
        validation_result = self.input_validator.validate(input_type, data)
        if not validation_result['valid']:
            duration = time.perf_counter() - start_time
            self.system_monitor.record_operation(duration, success=False)
            return {
                'success': False,
                'error': 'Input validation failed',
                'validation': validation_result,
                'metrics': self._timing(duration)
            }

        # Step 2: Execute operation. `operation` is an arbitrary caller-supplied
        # callable, so any exception is reported as a failed result.
        try:
            result = operation(data)
        except Exception as e:
            duration = time.perf_counter() - start_time
            self.system_monitor.record_operation(duration, success=False)
            return {
                'success': False,
                'error': str(e),
                'validation': validation_result,
                'metrics': self._timing(duration)
            }

        # Step 3: Verify output
        verification_result = self.output_verifier.verify(output_type, result)

        # Step 4: Record metrics
        duration = time.perf_counter() - start_time
        self.system_monitor.record_operation(duration, success=verification_result['verified'])

        return {
            'success': verification_result['verified'],
            'result': result,
            'validation': validation_result,
            'verification': verification_result,
            'metrics': self._timing(duration)
        }

    def get_system_report(self) -> str:
        """Generate comprehensive system report"""
        diagnostics = self.system_monitor.get_diagnostics()
        # get_diagnostics() already embeds the health status; reuse it rather
        # than computing it a second time.
        health = diagnostics['health_status']

        report = f"""
Eden Core - System Robustness Report
{'='*50}
Generated: {datetime.now().isoformat()}

HEALTH STATUS: {'HEALTHY' if health['healthy'] else 'UNHEALTHY'}
{'-'*50}
Current Metrics:
  CPU Usage: {health['current']['cpu']:.1f}%
  Memory Usage: {health['current']['memory']:.1f}%
  Error Rate: {health['current']['error_rate']:.2%}
  Avg Response Time: {health['current']['response_time']:.3f}s

Trends:
  CPU: {health['trends']['cpu_trend']}
  Memory: {health['trends']['memory_trend']}
  Errors: {health['trends']['error_trend']}

PERFORMANCE SUMMARY:
{'-'*50}
Total Operations: {diagnostics['performance']['total_operations']}
Total Errors: {diagnostics['performance']['total_errors']}
Overall Error Rate: {diagnostics['performance']['error_rate']:.2%}
Average Response Time: {diagnostics['performance']['avg_response_time']:.3f}s

ALERTS: {len(diagnostics['alerts'])} recent alerts

RECOMMENDATIONS:
{'-'*50}
"""

        # Add recommendations based on health
        if health['current']['cpu'] > 70:
            report += "⚠ High CPU usage - Consider optimizing compute-intensive operations\n"
        if health['current']['memory'] > 75:
            report += "⚠ High memory usage - Review memory management and caching strategies\n"
        if health['current']['error_rate'] > 0.05:
            report += "⚠ Elevated error rate - Review error logs and improve error handling\n"
        if health['current']['response_time'] > 3.0:
            report += "⚠ Slow response times - Optimize critical path operations\n"

        if health['healthy']:
            report += "✓ All systems operating within normal parameters\n"

        return report


# Example usage: run one query through the full pipeline, then print the report.
if __name__ == "__main__":
    def example_operation(data):
        """Toy operation: echo the input together with a fixed confidence."""
        return {"processed": data, "confidence": 0.95}

    manager = RobustSystemManager()

    # Validate, execute, verify and record the operation in one call
    outcome = manager.process_with_validation(
        input_type="query",
        data="What is artificial general intelligence?",
        operation=example_operation,
        output_type="reasoning_result",
    )

    print(json.dumps(outcome, indent=2))
    report = manager.get_system_report()
    print("\n" + report)
