#!/usr/bin/env python3
"""
EDEN SAGE MATH GUARD (SymPy Integration)
- Role: Mathematical Code Verification
- Input: Python Source Code + Ground Truth Formula
- Mechanism: AST Extraction -> Symbolic Simplification -> Zero Check
"""
import ast
import sympy as sp
from sympy.parsing.sympy_parser import parse_expr

class MathGuard:
    """Verify that formulas returned by Python code match a symbolic ground truth.

    Every ``return`` expression in a code chunk is extracted with :mod:`ast`,
    parsed into SymPy (with code-level names renamed to the ground-truth
    symbols), and accepted when ``sympy.simplify(code - truth)`` is 0.
    """

    def __init__(self):
        # Not read or written by any method of this class; presumably reserved
        # for callers to record verdicts on the instance — confirm.
        self.verdicts = []

    def extract_formulas(self, code_str):
        """Extract the source text of every ``return`` expression in *code_str*.

        Args:
            code_str: Python source code (typically one or more functions).

        Returns:
            A list of expression strings produced by :func:`ast.unparse`, in
            :func:`ast.walk` (breadth-first) order. Bare ``return`` statements
            are skipped. Source that cannot be parsed yields an empty list.
        """
        try:
            tree = ast.parse(code_str)
        except (SyntaxError, ValueError, TypeError, RecursionError):
            # Best effort: input that is not valid Python has no formulas.
            return []

        formulas = []
        for node in ast.walk(tree):
            if not isinstance(node, ast.Return):
                continue
            formula_ast = self._get_expr_node(node)
            if formula_ast is None:
                continue
            try:
                formulas.append(ast.unparse(formula_ast))
            except RecursionError:
                # Pathologically deep expressions cannot be unparsed; skip them.
                continue
        return formulas

    def _get_expr_node(self, node):
        """Return the expression of the first ``Return`` found in *node*.

        If *node* is itself a ``Return``, its ``value`` is returned directly
        (``None`` for a bare ``return``). Otherwise the children are searched
        depth-first in field order, and the first ``Return`` that has a value wins.
        Returns ``None`` when nothing is found or *node* is not an AST node.
        """
        if isinstance(node, ast.Return):
            return node.value
        if not isinstance(node, ast.AST):
            return None
        for child in ast.iter_child_nodes(node):
            found = self._get_expr_node(child)
            if found is not None:
                return found
        return None

    @staticmethod
    def _as_symbol(target):
        """Return *target* as a SymPy object; plain names become ``Symbol``s."""
        if isinstance(target, sp.Basic):
            return target
        return sp.Symbol(str(target))

    def verify_equivalence(self, code_chunk, ground_truth_str, var_map, check_all=True):
        """Check the ``return`` formulas of *code_chunk* against a ground truth.

        Args:
            code_chunk: Python source whose ``return`` expressions are checked.
            ground_truth_str: Reference formula in SymPy/Python syntax, written
                with the target symbol names (e.g. ``"0.5*m*v**2"``).
            var_map: Maps code-level names to ground-truth symbol names, e.g.
                ``{"mass": "m", "velocity": "v"}``. Values may be strings or
                SymPy symbols; passing symbols with assumptions (such as
                ``positive=True``) applies those assumptions on both sides.
                Names that are not listed are used unchanged.
            check_all: When True (default), every extracted formula must be
                equivalent. When False (legacy mode), one equivalent formula
                is enough.

        Returns:
            ``(ok, message)``. ``ok`` tells whether verification succeeded, and
            ``message`` describes the outcome.

        Exceptions raised by ``parse_expr`` for an invalid *ground_truth_str*
        propagate to the caller.

        NOTE(review): the zero test is structural. Float coefficients that
        differ only by rounding may leave a tiny nonzero residue.
        """
        code_formulas = self.extract_formulas(code_chunk)

        if not code_formulas:
            return False, "No return statements with expressions found."

        print(f"📊 Extracted {len(code_formulas)} formula{'' if len(code_formulas) == 1 else 's'}: {', '.join(code_formulas[:3])}")

        # Rename at parse time: each code-level name in var_map parses straight
        # to its ground-truth Symbol. Using local_dict also stops names such as
        # E, I, N or S from being captured by SymPy's built-in objects.
        rename = {
            code_var: self._as_symbol(sym_var)
            for code_var, sym_var in (var_map or {}).items()
        }
        # Parse the ground truth with the very same symbols so both sides agree.
        truth_names = {
            sym.name: sym for sym in rename.values() if isinstance(sym, sp.Symbol)
        }
        truth_expr = parse_expr(ground_truth_str, local_dict=truth_names)

        results = []
        for code_formula in code_formulas:
            try:
                # SECURITY: parse_expr evaluates its input with eval(). The
                # formula text comes from the analysed source, so untrusted code
                # (e.g. a return of __import__('os').system(...)) can execute
                # here. Sandbox or pre-validate the AST for untrusted input.
                expr_code = parse_expr(code_formula, local_dict=dict(rename))
                # Equivalent formulas simplify to a zero difference.
                diff = sp.simplify(expr_code - truth_expr)
            except Exception as e:
                print(f"❌ Parsing error for formula '{code_formula}': {e}")
                return False, "Parsing failed"

            results.append((code_formula, diff))

            if check_all and diff != 0:
                print(f"❌ Formula not mathematically perfect: {code_formula}")
                return False, f"Simplification Difference: {diff}"

        if check_all:
            # Any non-equivalent formula has already returned above.
            return True, "All extracted formulas are mathematically perfect."

        # Legacy mode: accept as soon as at least one formula is equivalent.
        correct = [formula for formula, diff in results if diff == 0]
        if not correct:
            return False, "Mathematical equivalence not verified for any formula."
        print(f"✅ Formula is mathematically perfect: {correct[0]}")
        return True, f"{len(correct)} alternative correct implementation{'s' if len(correct) > 1 else ''} found."

if __name__ == "__main__":
    guard = MathGuard()

    # --- TEST CASE: KINETIC ENERGY ---
    # The Law: K = 0.5 * m * v^2
    truth = "0.5*m*v**2"
    # Code parameter names -> ground-truth symbol names.
    kinetic_names = {'mass': 'm', 'velocity': 'v'}

    print("Testing Multiple Formula Extraction & Verification\n")

    # Test Case 1: Single Correct Formula
    code_single_correct = """
def calc_kinetic(mass, velocity):
    return 0.5 * mass * (velocity ** 2)
"""
    print("🧪 TEST CASE 1 - SINGLE CORRECT FORMULA:")
    result, msg = guard.verify_equivalence(code_single_correct, truth, kinetic_names)
    print(f"{msg}\n")

    # Test Case 2: Two equivalent implementations written in different forms.
    # ``another_calc`` already uses the ground-truth name ``v``, and
    # (mass * v**2) / 2 equals 0.5 * m * v**2 once mass -> m.
    code_multiple = """
def calc_kinetic(mass, velocity):
    return 0.5 * mass * velocity**2

def another_calc(mass, v):
    return (mass * v**2) / 2
"""
    print("🧪 TEST CASE 2 - MULTIPLE FORMULAS (BOTH CORRECT, DIFFERENT FORMS):")
    result_all, msg_all = guard.verify_equivalence(code_multiple, truth, kinetic_names, check_all=True)
    print(f"{msg_all}\n")

    # Test Case 3: Single Incorrect Formula (missing the 0.5 factor), legacy mode
    code_single_incorrect = """
def calc_kinetic(mass, velocity):
    return mass * velocity**2
"""
    print("🧪 TEST CASE 3 - SINGLE INCORRECT FORMULA:")
    result_one, msg_one = guard.verify_equivalence(code_single_incorrect, truth, kinetic_names, check_all=False)
    print(f"{msg_one}\n")

    # Test Case 4: One correct and one incorrect formula under check_all=True
    code_mixed = """
def calc_kinetic(mass, velocity):
    return 0.5 * mass * velocity**2

def wrong_calc(mass, velocity):
    return mass * velocity**2
"""
    print("🧪 TEST CASE 4 - MULTIPLE FORMULAS (ONE CORRECT, ONE INCORRECT):")
    result_mixed, msg_mixed = guard.verify_equivalence(code_mixed, truth, kinetic_names, check_all=True)
    print(f"{msg_mixed}\n")