"""
EDEN REASONING VERIFIER - Auto-check reasoning chains
Created: Jan 26, 2026

Validates that reasoning is:
1. Logically consistent
2. Factually grounded
3. Free of contradictions
"""
import ast
import operator
import re
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

@dataclass
class VerificationResult:
    """Outcome of one verification check.

    Every producer in this module sets ``valid`` to ``len(issues) == 0`` and
    keeps ``confidence`` within [0.0, 1.0].
    """

    valid: bool  # True when no issues were recorded
    confidence: float  # heuristic score in [0.0, 1.0]
    issues: List[str]  # human-readable problems that were found
    suggestions: List[str]  # human-readable hints for addressing the issues


class EdenReasoningVerifier:
    """Heuristically verify reasoning chains, causal claims and arithmetic.

    Chain checks are regex/keyword heuristics over lower-cased text, so they
    can yield false positives and false negatives. Causal checks delegate to
    an optional Clingo backend. Arithmetic is evaluated with a restricted AST
    walker rather than ``eval``.
    """

    # Upper bound on the size of any integer produced while evaluating an
    # arithmetic expression, so inputs such as "9**9**9**9" cannot hang the
    # process or exhaust memory.
    _MAX_INT_BITS = 10_000

    # Operators accepted by the arithmetic evaluator. verify_math keeps only
    # digits, "+-*/.()" and spaces, so these cover every operator that text
    # can express.
    _BIN_OPS = {
        ast.Add: operator.add,
        ast.Sub: operator.sub,
        ast.Mult: operator.mul,
        ast.Div: operator.truediv,
        ast.FloorDiv: operator.floordiv,
        ast.Pow: operator.pow,
    }
    _UNARY_OPS = {
        ast.UAdd: operator.pos,
        ast.USub: operator.neg,
    }

    def __init__(self):
        # Clingo is optional. Any failure to import or construct the solver
        # (missing module, solver init error) disables Clingo-backed checks.
        try:
            from eden_clingo_solver import EdenClingoSolver
            self.clingo = EdenClingoSolver()
            self.clingo_enabled = True
        except Exception:
            self.clingo = None
            self.clingo_enabled = False

        # Regex heuristics for common logical fallacies, searched in the
        # lower-cased, space-joined reasoning chain.
        self.fallacies: Dict[str, List[str]] = {
            "correlation_causation": [
                r"therefore.*causes",
                r"because.*happens.*must.*cause",
            ],
            "ad_hominem": [
                r"you're wrong because you",
                r"can't trust.*because.*person",
            ],
            "false_dichotomy": [
                r"either.*or.*only.*options",
                r"must be one or the other",
            ],
            "hasty_generalization": [
                r"all.*are.*because.*one",
                r"everyone.*always",
            ],
            "circular_reasoning": [
                r"true because.*true",
                r"proof.*is.*itself",
            ],
        }

        # (positive, negative) regex pairs, matched on word boundaries. A
        # back-reference such as "\1" in the negative pattern stands for the
        # text captured by the positive pattern.
        self.contradiction_pairs: List[Tuple[str, str]] = [
            (r"is (\w+)", r"is not \1"),
            (r"always", r"never"),
            (r"all", r"none"),
            (r"true", r"false"),
        ]

    def verify_chain(self, reasoning_chain: List[str]) -> VerificationResult:
        """Verify a reasoning chain for logical consistency.

        Args:
            reasoning_chain: ordered reasoning steps (strings).

        Returns:
            A VerificationResult with one issue per detected fallacy, per
            contradicting phrase pair, and per "therefore"/"thus" step that
            shares no content word with the step before it. Confidence is
            ``1.0 - 0.2 * len(issues)``, clamped to [0.0, 1.0].
        """
        issues: List[str] = []
        suggestions: List[str] = []

        full_text = " ".join(reasoning_chain).lower()

        # Fallacy heuristics. Each fallacy is reported at most once, even
        # when several of its patterns match.
        for fallacy_name, patterns in self.fallacies.items():
            if any(re.search(p, full_text, re.IGNORECASE) for p in patterns):
                label = fallacy_name.replace("_", " ")
                issues.append(f"Potential {label} detected")
                suggestions.append(f"Review logic for {label}")

        # Internal contradictions, on whole words only, so that "all" does
        # not fire inside "small".
        for pos_pattern, neg_pattern in self.contradiction_pairs:
            for pos_phrase, neg_phrase in self._find_contradictions(
                pos_pattern, neg_pattern, full_text
            ):
                issues.append(
                    f"Potential contradiction: '{pos_phrase}' and '{neg_phrase}' both present"
                )

        # Non-sequitur check: a step introduced by "therefore"/"thus" must
        # share at least one content word with the step before it.
        for index, (premise, conclusion) in enumerate(
            zip(reasoning_chain, reasoning_chain[1:])
        ):
            lowered = conclusion.lower()
            if ("therefore" in lowered or "thus" in lowered) and not self._conclusion_follows(
                premise, conclusion
            ):
                issues.append(f"Step {index + 2} may not follow from step {index + 1}")

        confidence = max(0.0, min(1.0, 1.0 - len(issues) * 0.2))

        return VerificationResult(
            valid=len(issues) == 0,
            confidence=confidence,
            issues=issues,
            suggestions=suggestions,
        )

    @staticmethod
    def _find_contradictions(
        pos_pattern: str, neg_pattern: str, text: str
    ) -> List[Tuple[str, str]]:
        """Return distinct (positive, negative) phrase pairs both found in *text*.

        Both patterns are matched on word boundaries. Back-references ``\\1`` to
        ``\\9`` in *neg_pattern* are replaced by the (regex-escaped) text that
        the corresponding group of *pos_pattern* captured. For example, the
        pair ``("is (\\w+)", "is not \\1")`` flags "is big" together with
        "is not big", and nothing else.
        """
        found: List[Tuple[str, str]] = []
        for pos_match in re.finditer(rf"\b(?:{pos_pattern})\b", text):
            # The substitution runs immediately, so binding pos_match in the
            # lambda is safe.
            neg_regex = re.sub(
                r"\\([1-9])",
                lambda ref: re.escape(pos_match.group(int(ref.group(1))) or ""),
                neg_pattern,
            )
            neg_match = re.search(rf"\b(?:{neg_regex})\b", text)
            if neg_match is None:
                continue
            pair = (pos_match.group(0), neg_match.group(0))
            if pair not in found:
                found.append(pair)
        return found

    def _conclusion_follows(self, premise: str, conclusion: str) -> bool:
        """Heuristically check whether *conclusion* follows from *premise*.

        Returns True when the two share at least one word of four or more
        word characters (case-insensitive).
        """
        premise_terms = set(re.findall(r'\b\w{4,}\b', premise.lower()))
        conclusion_terms = set(re.findall(r'\b\w{4,}\b', conclusion.lower()))

        # Require some lexical overlap.
        overlap = premise_terms & conclusion_terms
        return len(overlap) >= 1

    def verify_causal_claim(self, cause: str, effect: str) -> VerificationResult:
        """Verify a causal claim with the Clingo backend.

        When Clingo is unavailable, nothing is checked and the result is
        valid with confidence 0.5. Otherwise the result of
        ``self.clingo.query_cause(cause, effect)`` is read as a mapping. It is
        presumably ``{"causes": True/False/None, "reason": ...}``; confirm
        against EdenClingoSolver. ``causes is False`` is reported as a
        spurious correlation and ``causes is None`` as not established.
        """
        issues: List[str] = []
        suggestions: List[str] = []

        if self.clingo_enabled:
            result = self.clingo.query_cause(cause, effect)

            if result.get("causes") is False:
                issues.append(f"Spurious correlation: {result.get('reason', 'unknown')}")
                suggestions.append("Consider confounding variables")
            elif result.get("causes") is None:
                issues.append("Causal relationship not established")
                suggestions.append("Need more evidence for causal claim")

        return VerificationResult(
            valid=len(issues) == 0,
            confidence=1.0 if self.clingo_enabled else 0.5,
            issues=issues,
            suggestions=suggestions,
        )

    @classmethod
    def _eval_arithmetic(cls, expr: str):
        """Evaluate a plain arithmetic expression without ``eval``.

        Accepts int/float literals, unary ``+``/``-`` and binary
        ``+ - * / // **``.

        Raises:
            SyntaxError: the text does not parse as an expression.
            ValueError: any other construct, an integer larger than
                ``_MAX_INT_BITS`` bits, or a complex result.
            ArithmeticError: e.g. ZeroDivisionError or OverflowError from the
                underlying operation.
        """

        def _check(value):
            if isinstance(value, complex):
                raise ValueError("complex result")
            if isinstance(value, int) and value.bit_length() > cls._MAX_INT_BITS:
                raise ValueError("integer too large")
            return value

        def _eval(node):
            if isinstance(node, ast.Constant) and type(node.value) in (int, float):
                return _check(node.value)
            if isinstance(node, ast.UnaryOp) and type(node.op) in cls._UNARY_OPS:
                return _check(cls._UNARY_OPS[type(node.op)](_eval(node.operand)))
            if isinstance(node, ast.BinOp) and type(node.op) in cls._BIN_OPS:
                left, right = _eval(node.left), _eval(node.right)
                # bit_length(a ** b) <= b * bit_length(a). Refuse before
                # computing a power that could not pass _check anyway.
                if (
                    isinstance(node.op, ast.Pow)
                    and isinstance(left, int)
                    and isinstance(right, int)
                    and abs(left) > 1
                    and right * left.bit_length() > cls._MAX_INT_BITS
                ):
                    raise ValueError("power too large")
                return _check(cls._BIN_OPS[type(node.op)](left, right))
            raise ValueError(f"unsupported element: {type(node).__name__}")

        # Strip leading whitespace, which compile() rejects in eval mode.
        return _eval(ast.parse(expr.strip(), mode="eval").body)

    def verify_math(self, expression: str, claimed_result: str) -> VerificationResult:
        """Check that the arithmetic in *expression* equals *claimed_result*.

        *expression* is reduced to digits, ``+ - * / . ( )`` and spaces. Spaces
        are kept so that "5 3" is rejected instead of being read as 53. The
        remainder is evaluated with ``_eval_arithmetic``. *claimed_result* is
        reduced to its digits, '.' and '-' and parsed with ``float``.

        Returns:
            - Values differ by more than 0.001: valid=False, confidence 0.0,
              one issue, and the suggestion "Double-check calculation".
            - Values match: valid=True, confidence 1.0.
            - Nothing could be checked (no arithmetic, unparsable expression
              or claim, division by zero, oversized result, ...): valid=True,
              confidence 0.5, no issues.
        """
        issues: List[str] = []
        verified = False

        try:
            expr = re.sub(r"[^\d+\-*/.() ]", "", expression)
            shown = "".join(expr.split())  # compact form used in messages
            if shown:
                actual = self._eval_arithmetic(expr)
                claimed = float(re.sub(r'[^\d\.\-]', '', str(claimed_result)))

                if abs(actual - claimed) > 0.001:
                    issues.append(f"Math error: {shown} = {actual}, not {claimed}")
                verified = True
        except (ArithmeticError, RecursionError, SyntaxError, TypeError, ValueError):
            # Undefined or unparsable input: nothing could be verified, so
            # report reduced confidence rather than a confident pass.
            pass

        if issues:
            confidence = 0.0
        elif verified:
            confidence = 1.0
        else:
            confidence = 0.5

        return VerificationResult(
            valid=len(issues) == 0,
            confidence=confidence,
            issues=issues,
            suggestions=["Double-check calculation"] if issues else [],
        )

    def full_verify(self,
                    answer: str,
                    reasoning_chain: Optional[List[str]] = None,
                    claim_type: Optional[str] = None,
                    **kwargs) -> VerificationResult:
        """Run every applicable check and merge their issues and suggestions.

        Args:
            answer: the final answer. For ``claim_type="math"`` it is the
                claimed result of ``kwargs["expression"]``.
            reasoning_chain: optional steps passed to ``verify_chain``.
            claim_type: ``"causal"`` (needs ``cause`` and ``effect`` kwargs)
                or ``"math"`` (needs an ``expression`` kwarg).

        Returns:
            A VerificationResult whose confidence is
            ``1.0 - 0.15 * len(issues)``, clamped to [0.0, 1.0].
        """
        all_issues: List[str] = []
        all_suggestions: List[str] = []

        if reasoning_chain:
            chain_result = self.verify_chain(reasoning_chain)
            all_issues.extend(chain_result.issues)
            all_suggestions.extend(chain_result.suggestions)

        # Type-specific verification
        if claim_type == "causal" and "cause" in kwargs and "effect" in kwargs:
            causal_result = self.verify_causal_claim(kwargs["cause"], kwargs["effect"])
            all_issues.extend(causal_result.issues)
            all_suggestions.extend(causal_result.suggestions)

        elif claim_type == "math" and "expression" in kwargs:
            math_result = self.verify_math(kwargs["expression"], answer)
            all_issues.extend(math_result.issues)
            all_suggestions.extend(math_result.suggestions)

        confidence = max(0.0, min(1.0, 1.0 - len(all_issues) * 0.15))

        return VerificationResult(
            valid=len(all_issues) == 0,
            confidence=confidence,
            issues=all_issues,
            suggestions=all_suggestions,
        )


# Quick access
def verify_reasoning(answer: str, chain: Optional[List[str]] = None, **kwargs) -> VerificationResult:
    """Run ``EdenReasoningVerifier.full_verify`` once, as a convenience.

    A fresh verifier is built on every call, so the optional Clingo backend
    is (re)initialised each time. *chain* is forwarded as the reasoning
    chain, and extra keyword arguments are passed through unchanged.
    """
    return EdenReasoningVerifier().full_verify(answer, chain, **kwargs)


if __name__ == "__main__":
    print("=== REASONING VERIFIER TEST ===\n")
    verifier = EdenReasoningVerifier()

    # Chain verification: the demo's "good" and "bad" example chains.
    demo_chains = [
        ("Good chain:", [
            "Smoking contains carcinogens",
            "Carcinogens damage DNA",
            "DNA damage causes cancer",
            "Therefore smoking causes cancer",
        ]),
        ("\nBad chain:", [
            "Ice cream sales are high",
            "Drowning rates are high",
            "Therefore ice cream causes drowning",
        ]),
    ]
    for heading, chain in demo_chains:
        print(heading)
        outcome = verifier.verify_chain(chain)
        print(f"  Valid: {outcome.valid}, Confidence: {outcome.confidence:.0%}")
        print(f"  Issues: {outcome.issues}")

    # Causal verification (Clingo-backed when the solver is available).
    print("\nCausal verification:")
    outcome = verifier.verify_causal_claim("ice_cream", "drowning")
    print(f"  Ice cream → drowning: Valid={outcome.valid}")
    print(f"  Issues: {outcome.issues}")

    # Math verification: one correct claim, then one incorrect claim.
    print("\nMath verification:")
    outcome = verifier.verify_math("25 + 17 + 8", "50")
    print(f"  25+17+8=50: Valid={outcome.valid}")

    outcome = verifier.verify_math("25 + 17 + 8", "45")
    print(f"  25+17+8=45: Valid={outcome.valid}, Issues={outcome.issues}")
