"""
EDEN CLINGO SOLVER - True Symbolic Logic (100% Deterministic)
Created: Jan 26, 2026
Uses Answer Set Programming (clingo) for deterministic reasoning: every answer follows exactly from the encoded facts and rules
"""
import re
from dataclasses import dataclass
from typing import List, Dict, Any, Optional

import clingo

@dataclass
class ClingoResult:
    """Outcome of a single ``EdenClingoSolver.solve`` call."""

    # One dict per answer set: {"atoms": [shown atoms as strings], "cost": model cost vector}.
    answers: List[Dict[str, Any]]
    # Satisfiability as reported by clingo's solve result.
    satisfiable: bool
    # Cost vector (clingo's Model.cost, a list of ints) of the best model found when the
    # program contains optimization statements; None when no cost was reported.
    optimal: Optional[List[int]] = None
    
class EdenClingoSolver:
    """Answer Set Programming solver for deterministic logic queries.

    Holds a fixed knowledge base of ASP facts (``base_facts``) and rules
    (``rules``). Every call to :meth:`solve` builds a fresh
    ``clingo.Control`` from that knowledge base plus the query text, so
    queries never share grounding state.
    """

    # Constants interpolated into an ASP program must be plain lowercase
    # identifiers; anything else can break parsing or inject extra rules.
    _TERM_RE = re.compile(r"[a-z][a-z0-9_]*")

    # (ASP predicate, label reported as the result's "type"), tried in order.
    _ANALOGY_RELATIONS = (
        ("opposite", "opposite"),
        ("action", "action"),
        ("royalty_pair", "royalty"),
        ("tool_of", "tool"),
        ("part_of", "part"),
        ("young_of", "young"),
    )

    def __init__(self):
        self.base_facts: List[str] = []
        self.rules: List[str] = []
        self._load_knowledge_base()

    def _load_knowledge_base(self):
        """Append the built-in facts to ``base_facts`` and rules to ``rules``."""
        # CAUSAL FACTS - EXPANDED Jan 26 2026
        self.base_facts.extend([
            # Direct causes - Medical/Health
            "causes(smoking, lung_cancer).",
            "causes(smoking, heart_disease).",
            "causes(alcohol, liver_damage).",
            "causes(sun_exposure, skin_cancer).",
            "causes(high_sodium, hypertension).",
            "causes(sedentary_lifestyle, obesity).",
            "causes(stress, high_blood_pressure).",
            "causes(virus, infection).",
            "causes(bacteria, food_poisoning).",
            
            # Direct causes - Physics
            "causes(rain, wet_ground).",
            "causes(fire, heat).",
            "causes(gravity, falling).",
            "causes(friction, heat).",
            "causes(electricity, light).",
            "causes(cold, freezing).",
            "causes(heat, evaporation).",
            "causes(pressure, compression).",
            
            # Direct causes - Behavior
            "causes(studying, good_grades).",
            "causes(practice, skill_improvement).",
            "causes(exercise, fitness).",
            "causes(overeating, weight_gain).",
            "causes(sleep_deprivation, fatigue).",
            "causes(dehydration, thirst).",
            "causes(saving, wealth).",
            "causes(learning, knowledge).",
            
            # Direct causes - Technology
            "causes(coding, software).",
            "causes(training, ai_model).",
            "causes(bugs, crashes).",
            "causes(optimization, speed).",
            
            # Confounders (spurious correlations)
            "confounder(ice_cream, drowning, temperature).",
            "confounder(shoe_size, reading_ability, age).",
            "confounder(yellow_fingers, lung_cancer, smoking).",
            "confounder(rooster_crow, sunrise, earth_rotation).",
            "confounder(umbrella_sales, car_accidents, rain).",
            "confounder(shark_attacks, ice_cream_sales, summer).",
            "confounder(birth_rate, stork_population, rural_areas).",
            "confounder(chocolate, nobel_prizes, wealth).",
            "confounder(organic_food, autism, time).",
            "confounder(pirates, global_warming, industrialization).",
            
            # Properties
            "fragile(glass). fragile(egg). fragile(ceramic). fragile(phone_screen).",
            "liquid(water). liquid(milk). liquid(oil). liquid(blood).",
            "hot(fire). hot(stove). hot(sun). hot(lava).",
            "cold(ice). cold(snow). cold(freezer). cold(space).",
            "heavy(rock). heavy(metal). heavy(water).",
            "light(feather). light(air). light(helium).",
        ])
        
        # CAUSAL RULES
        self.rules.extend([
            # Transitivity of causation
            "causes(X, Z) :- causes(X, Y), causes(Y, Z).",
            
            # Confounder detection
            "spurious_correlation(X, Y) :- confounder(X, Y, _).",
            "spurious_correlation(X, Y) :- confounder(Y, X, _).",
            "real_cause(X, Y) :- causes(X, Y), not spurious_correlation(X, Y).",
            
            # Physics rules
            "breaks(X) :- fragile(X), dropped(X).",
            "wet(X) :- liquid(L), spilled_on(L, X).",
            "burns(X) :- hot(H), touches(X, H).",
            
            # Analogies as relations
            "opposite(hot, cold). opposite(cold, hot).",
            "opposite(big, small). opposite(small, big).",
            "opposite(up, down). opposite(down, up).",
            "opposite(fast, slow). opposite(slow, fast).",
            
            "action(bird, fly). action(fish, swim). action(snake, slither).",
            "action(dog, bark). action(cat, meow). action(cow, moo).",
            "action(horse, gallop). action(frog, hop). action(kangaroo, jump).",
            "action(bee, buzz). action(lion, roar). action(wolf, howl).",
            
            "royalty_pair(king, queen). royalty_pair(prince, princess).",
            "royalty_pair(emperor, empress). royalty_pair(duke, duchess).",
            
            "tool_of(hammer, carpenter). tool_of(scalpel, surgeon).",
            "tool_of(brush, painter). tool_of(keyboard, programmer).",
            
            "part_of(wheel, car). part_of(leaf, tree).",
            "part_of(key, keyboard). part_of(pixel, screen).",
            
            "young_of(puppy, dog). young_of(kitten, cat).",
            "young_of(calf, cow). young_of(foal, horse).",
        ])

    @staticmethod
    def _normalize_term(text: str) -> str:
        """Turn free text into an ASP-style constant name.

        Lowercases, trims surrounding whitespace and trailing punctuation,
        and joins the remaining words with single underscores, e.g.
        ``"Lung  Cancer?"`` -> ``"lung_cancer"``.
        """
        return "_".join(text.lower().strip().strip("?.!,;:").split())

    @staticmethod
    def _humanize(term: str) -> str:
        """Render an ASP constant for display: ``lung_cancer`` -> ``lung cancer``."""
        return term.replace("_", " ")

    def solve(self, query: str, additional_facts: Optional[List[str]] = None) -> ClingoResult:
        """Ground and solve the knowledge base plus ``query``; return every answer set.

        Args:
            query: ASP text appended last (rules, facts, ``#show`` directives).
            additional_facts: Optional ASP statements appended before ``query``.

        Returns:
            A ClingoResult with one ``{"atoms": [...], "cost": [...]}`` dict per
            model (atoms restricted to shown symbols), the satisfiability clingo
            reported, and ``optimal`` set to the last model's non-empty cost vector.
        """
        ctl = clingo.Control(["0", "--warn=none"])  # "0" = enumerate all models

        program = "\n".join(self.base_facts + self.rules)
        if additional_facts:
            program += "\n" + "\n".join(additional_facts)
        program += f"\n{query}"

        ctl.add("base", [], program)
        ctl.ground([("base", [])])

        answers: List[Dict[str, Any]] = []
        with ctl.solve(yield_=True) as handle:
            for model in handle:
                atoms = [str(atom) for atom in model.symbols(shown=True)]
                answers.append({"atoms": atoms, "cost": model.cost})
            satisfiable = handle.get().satisfiable

        # While optimizing, clingo yields successively better models, so the
        # LAST model carries the best cost. Without #minimize/#maximize the
        # cost vector is empty and ``optimal`` stays None.
        last_cost = answers[-1]["cost"] if answers else None
        return ClingoResult(
            answers=answers,
            satisfiable=satisfiable,
            optimal=last_cost or None,
        )

    def query_cause(self, x: str, y: str) -> Dict[str, Any]:
        """Decide whether ``x`` causes ``y`` using the derived causal atoms.

        Inputs are normalized with ``_normalize_term`` ("Lung cancer?" ->
        ``lung_cancer``). Direct/transitive causation is checked before
        spurious correlation, so a pair with both is reported as causal.

        Returns:
            A dict with ``causes`` (True, False for a known confounded pair,
            or None when nothing is known), ``reason``, ``confidence`` and
            ``verified``.
        """
        x, y = self._normalize_term(x), self._normalize_term(y)
        if not x or not y:
            return {"causes": None, "reason": "Both a cause and an effect are required.",
                    "confidence": 0.0, "verified": False}

        # Get all causal atoms with explicit #show
        result = self.solve("#show causes/2. #show confounder/3. #show spurious_correlation/2. #show real_cause/2.")
        if not result.satisfiable or not result.answers:
            return {"causes": None, "reason": "Could not solve", "confidence": 0.0, "verified": False}

        atom_list = result.answers[0]["atoms"]
        atoms = set(atom_list)  # O(1) membership tests
        x_text, y_text = self._humanize(x), self._humanize(y)

        if f"causes({x},{y})" in atoms or f"real_cause({x},{y})" in atoms:
            return {
                "causes": True,
                "reason": f"Yes, {x_text} causes {y_text} (verified by symbolic logic).",
                "confidence": 1.0,
                "verified": True
            }

        if f"spurious_correlation({x},{y})" in atoms or f"spurious_correlation({y},{x})" in atoms:
            confounder = "a common factor"
            prefixes = (f"confounder({x},{y},", f"confounder({y},{x},")
            # Iterate the ordered list (not the set) so the choice is deterministic;
            # as before, the last matching confounder atom wins.
            for atom in atom_list:
                for prefix in prefixes:
                    if atom.startswith(prefix):
                        confounder = self._humanize(atom[len(prefix):-1])
            return {
                "causes": False,
                "reason": f"Spurious correlation! {x_text} and {y_text} are both influenced by {confounder}.",
                "confidence": 1.0,
                "verified": True
            }

        return {
            "causes": None,
            "reason": f"No causal relationship found between {x_text} and {y_text}.",
            "confidence": 1.0,
            "verified": True
        }

    def _first_answer(self, program: str) -> Optional[str]:
        """Solve ``program`` and return the argument of the first ``answer/1`` atom, or None."""
        prefix = "answer("
        for ans in self.solve(program).answers:
            for atom in ans["atoms"]:
                if atom.startswith(prefix) and atom.endswith(")"):
                    return atom[len(prefix):-1]
        return None

    def query_analogy(self, a: str, b: str, c: str) -> Dict[str, Any]:
        """Solve ``A:B :: C:?`` by finding a relation that links A to B and applying it to C.

        Relations in ``_ANALOGY_RELATIONS`` are tried in order; for each, the
        forward pattern rel(A, B), rel(C, D) is tried before the reversed
        pattern rel(B, A), rel(D, C). Terms that are not valid lowercase ASP
        identifiers after normalization yield the "unknown" result instead
        of being spliced into the program.
        """
        unknown = {"answer": None, "confidence": 0.0, "verified": False, "type": "unknown"}
        a, b, c = (self._normalize_term(term) for term in (a, b, c))
        if not all(self._TERM_RE.fullmatch(term) for term in (a, b, c)):
            return unknown

        for rel_name, rel_type in self._ANALOGY_RELATIONS:
            patterns = (
                f"answer(D) :- {rel_name}({a}, {b}), {rel_name}({c}, D).",  # A->B
                f"answer(D) :- {rel_name}({b}, {a}), {rel_name}(D, {c}).",  # B->A
            )
            for rule in patterns:
                d = self._first_answer(f"{rule} #show answer/1.")
                if d is not None:
                    return {"answer": d, "confidence": 1.0, "verified": True, "type": rel_type}

        return unknown

    def _fragile_objects(self) -> List[str]:
        """Return every X for which the knowledge base derives ``fragile(X)``."""
        result = self.solve("#show fragile/1.")
        if not result.answers:
            return []
        prefix = "fragile("
        return [atom[len(prefix):-1] for atom in result.answers[0]["atoms"] if atom.startswith(prefix)]

    def query_physics(self, scenario: str) -> Dict[str, Any]:
        """Predict the outcome of dropping a fragile object mentioned in ``scenario``.

        Candidates come from the knowledge base's ``fragile/1`` facts and must
        appear as whole words (optionally plural, underscores or spaces), so
        "eggplant" does not count as an egg. When several are mentioned, the
        earliest mention is used.
        """
        scenario = scenario.lower()

        if "drop" in scenario:
            mentions = []
            for obj in self._fragile_objects():
                word = re.escape(obj).replace("_", "[ _]")
                found = re.search(rf"\b{word}(?:e?s)?\b", scenario)
                if found:
                    mentions.append((found.start(), obj))

            for _, obj in sorted(mentions):
                result = self.solve(f"dropped({obj}). #show breaks/1.")
                if any(f"breaks({obj})" in ans["atoms"] for ans in result.answers):
                    name = self._humanize(obj)
                    return {
                        "outcome": f"The {name} breaks.",
                        "reason": f"{name} is fragile and dropped objects with fragility break.",
                        "confidence": 1.0,
                        "verified": True
                    }

        return {"outcome": "Unknown", "confidence": 0.5, "verified": False}


# Quick access function
def clingo_reason(question: str) -> Dict[str, Any]:
    """Route a natural-language question to the matching EdenClingoSolver query.

    Recognized forms, checked in this order:
      * causal:  "does X cause Y?"                        -> query_cause
      * analogy: "a:b :: c:?" or "a is to b as c is to ?" -> query_analogy
      * physics: mentions drop / fall / break / spill     -> query_physics

    Returns the chosen query's result dict, or
    ``{"answer": "Query type not recognized", "verified": False}``.
    The solver (and its knowledge base) is only built for recognized queries.
    """
    q = question.lower()

    # Causal queries
    match = re.search(r'does\s+(.+?)\s+cause\s+(.+?)(?:\?|$)', q)
    if match:
        # The lazy groups can keep stray whitespace/punctuation,
        # e.g. "lung cancer." or "lung cancer " before the "?".
        cause, effect = (group.strip(" \t\n.!,;") for group in match.groups())
        return EdenClingoSolver().query_cause(cause, effect)

    # Analogy queries; re.ASCII because only ASCII words can be ASP constants.
    match = re.search(
        r'(\w+)\s*(?:is\s+to|:)\s*(\w+)\s*(?:as|::)\s*(\w+)\s*(?:is\s+to|:)\s*\??',
        q,
        re.ASCII,
    )
    if match:
        return EdenClingoSolver().query_analogy(match.group(1), match.group(2), match.group(3))

    # Physics queries
    if any(w in q for w in ("drop", "fall", "break", "spill")):
        return EdenClingoSolver().query_physics(q)

    return {"answer": "Query type not recognized", "verified": False}


if __name__ == "__main__":
    print("=== CLINGO SOLVER TEST ===\n")
    demo_solver = EdenClingoSolver()

    # Each case: (question shown to the reader, solver method name, positional args).
    demo_cases = [
        ("Does smoking cause lung cancer?", "query_cause", ("smoking", "lung_cancer")),
        ("Does ice cream cause drowning?", "query_cause", ("ice_cream", "drowning")),
        ("hot:cold :: big:?", "query_analogy", ("hot", "cold", "big")),
        ("bird:fly :: fish:?", "query_analogy", ("bird", "fly", "fish")),
        ("What if I drop a glass?", "query_physics", ("drop glass",)),
    ]

    for prompt, method_name, method_args in demo_cases:
        print(f"❓ {prompt}")
        # Dispatch by method name; each case supplies exactly that method's arguments.
        outcome = getattr(demo_solver, method_name)(*method_args)
        print(f"   → {outcome}")
        print(f"   ✓ Verified: {outcome.get('verified', False)}")
        print()
