#!/usr/bin/env python3
"""
EDEN EMBODIMENT ENGINE
======================
Physical grounding through simulation.
Eden can reason about: gravity, collisions, object permanence, spatial relations.
"""
import pybullet as p
import pybullet_data
import math
from dataclasses import dataclass
from typing import List, Tuple, Optional
import sys
sys.path.insert(0, '/Eden/CORE')

PHI = 1.618033988749895

@dataclass()
class PhysicalObject:
    """Record describing one named body placed in the physics world.

    ``body_id`` stays at -1 until ``EdenEmbodiment.spawn_object`` creates the
    body and stores the PyBullet id here.
    """

    name: str                                           # registry key
    position: Tuple[float, float, float]                # (x, y, z) at spawn time
    velocity: Tuple[float, float, float] = (0, 0, 0)    # stored as given; not refreshed from the simulation
    mass: float = 1.0                                   # 0 gives a static body
    body_id: int = -1                                   # PyBullet body id, -1 = not spawned

@dataclass()
class PhysicsQuery:
    """One physics question, Eden's answer to it, and the simulated check."""

    question: str                               # the question being asked
    objects: List[PhysicalObject]               # bodies involved in the query
    prediction: Optional[str] = None            # Eden's answer before simulating
    simulated_result: Optional[str] = None      # what the simulation showed
    confidence: float = 0.0                     # score set by the query method

class EdenEmbodiment:
    """Eden's physical intuition through simulation.

    Each instance owns one headless PyBullet client (``p.DIRECT``) and passes
    that client id to every PyBullet call, so several instances in the same
    process never act on each other's worlds. The query methods
    (``will_it_fall``, ``will_they_collide``, ``what_happens_if``) rebuild the
    world before running, so their results do not depend on earlier queries.
    """

    def __init__(self):
        self.physics_client = None  # PyBullet client id; None when disconnected
        self.objects = {}           # name -> PhysicalObject in the current world
        self._connect()

    def _connect(self):
        """(Re)create an empty world containing gravity and a ground plane.

        Disconnects this instance's previous client, if any, and clears the
        object registry, because body ids from the old world are not valid
        in the new one.
        """
        if self.physics_client is not None:
            try:
                p.disconnect(physicsClientId=self.physics_client)
            except p.error:
                # Best effort: the old client may already be disconnected.
                pass
        self.physics_client = p.connect(p.DIRECT)
        self.objects = {}
        cid = self.physics_client
        p.setAdditionalSearchPath(pybullet_data.getDataPath(), physicsClientId=cid)
        p.setGravity(0, 0, -9.81, physicsClientId=cid)
        # Load ground plane (found via the pybullet_data search path above).
        self.ground = p.loadURDF('plane.urdf', physicsClientId=cid)
        print("[EMBODIMENT] Physics world initialized")

    def spawn_object(self, name: str, shape: str = "sphere",
                     position: Tuple[float, float, float] = (0, 0, 1),
                     mass: float = 1.0) -> PhysicalObject:
        """Spawn an object in the simulation and register it under ``name``.

        Args:
            name: Key in ``self.objects``. An existing entry with the same
                name is replaced; its old body stays in the world.
            shape: ``"sphere"`` (radius 0.1), ``"box"`` (half-extents 0.1) or
                ``"cylinder"`` (radius 0.1, height 0.2). Any other value
                falls back to a sphere.
            position: Initial base position ``(x, y, z)``.
            mass: Body mass; ``0`` creates a static body.

        Returns:
            The new PhysicalObject, with ``body_id`` set to the PyBullet id.
        """
        cid = self.physics_client
        if shape == "box":
            collision = p.createCollisionShape(
                p.GEOM_BOX, halfExtents=[0.1, 0.1, 0.1], physicsClientId=cid)
            visual = p.createVisualShape(
                p.GEOM_BOX, halfExtents=[0.1, 0.1, 0.1], physicsClientId=cid)
        elif shape == "cylinder":
            # PyBullet calls the cylinder length `height` for collision
            # shapes but `length` for visual shapes.
            collision = p.createCollisionShape(
                p.GEOM_CYLINDER, radius=0.1, height=0.2, physicsClientId=cid)
            visual = p.createVisualShape(
                p.GEOM_CYLINDER, radius=0.1, length=0.2, physicsClientId=cid)
        else:
            # "sphere" and any unrecognised shape name.
            collision = p.createCollisionShape(
                p.GEOM_SPHERE, radius=0.1, physicsClientId=cid)
            visual = p.createVisualShape(
                p.GEOM_SPHERE, radius=0.1, physicsClientId=cid)

        body_id = p.createMultiBody(
            baseMass=mass,
            baseCollisionShapeIndex=collision,
            baseVisualShapeIndex=visual,
            basePosition=position,
            physicsClientId=cid,
        )

        obj = PhysicalObject(name=name, position=position, mass=mass, body_id=body_id)
        self.objects[name] = obj
        return obj

    def _advance(self, steps: int, dt: float,
                 watch: Optional[Tuple[int, int]] = None) -> bool:
        """Step the world ``steps`` times with time step ``dt``.

        If ``watch`` is a pair of body ids, return True when PyBullet reported
        at least one contact point between them after any step. Otherwise
        return False.
        """
        cid = self.physics_client
        p.setTimeStep(dt, physicsClientId=cid)
        touched = False
        for _ in range(steps):
            p.stepSimulation(physicsClientId=cid)
            if watch is not None and not touched:
                touched = bool(p.getContactPoints(
                    bodyA=watch[0], bodyB=watch[1], physicsClientId=cid))
        return touched

    def simulate(self, steps: int = 240, dt: float = 1/240) -> dict:
        """Run simulation and return final states.

        Returns:
            ``{name: {"position": (x, y, z), "velocity": (vx, vy, vz),
            "at_rest": bool}}`` for every registered object. ``at_rest`` means
            the squared linear speed is below 0.01 (speed < 0.1).
        """
        cid = self.physics_client
        self._advance(steps, dt)

        results = {}
        for name, obj in self.objects.items():
            pos, _orn = p.getBasePositionAndOrientation(obj.body_id, physicsClientId=cid)
            vel, _ang_vel = p.getBaseVelocity(obj.body_id, physicsClientId=cid)
            results[name] = {
                "position": pos,
                "velocity": vel,
                "at_rest": sum(v**2 for v in vel) < 0.01,
            }
        return results

    def will_it_fall(self, obj_name: str, height: float = 1.0) -> PhysicsQuery:
        """Predict if an object will fall and verify through simulation.

        Drops a sphere named ``obj_name`` from ``(0, 0, height)`` in a fresh
        world and simulates 240 steps (1 s at the default time step).
        """
        self._connect()  # Reset world (also clears stale object refs)

        obj = self.spawn_object(obj_name, "sphere", (0, 0, height))

        # Eden's prediction based on physics intuition
        prediction = f"Yes, {obj_name} will fall due to gravity (g=9.81 m/s²)"

        state = self.simulate(steps=240)[obj_name]  # 1 second
        final_z = state["position"][2]
        at_rest = state["at_rest"]

        # The sphere's radius is 0.1, so a centre below 0.15 means it lies on the plane.
        if final_z < 0.15 and at_rest:
            simulated = f"{obj_name} fell from {height}m to {final_z:.2f}m and is at rest"
            confidence = 0.95
        else:
            simulated = f"{obj_name} at {final_z:.2f}m, still moving: {not at_rest}"
            confidence = 0.7

        return PhysicsQuery(
            question=f"Will {obj_name} fall from {height}m?",
            objects=[obj],
            prediction=prediction,
            simulated_result=simulated,
            confidence=confidence
        )

    def will_they_collide(self, obj1_name: str, obj2_name: str,
                          pos1: Tuple[float, float, float],
                          pos2: Tuple[float, float, float],
                          vel1: Tuple[float, float, float] = (1, 0, 0)) -> PhysicsQuery:
        """Predict collision between two objects, then check by simulation.

        ``obj1`` is a dynamic sphere launched with ``vel1``. ``obj2`` is a
        static sphere (mass 0). A collision is counted when PyBullet reports a
        contact between the two bodies at any of the 480 simulated steps.
        """
        self._connect()  # Reset world (also clears stale object refs)

        obj1 = self.spawn_object(obj1_name, "sphere", pos1)
        obj2 = self.spawn_object(obj2_name, "sphere", pos2, mass=0)  # Static

        p.resetBaseVelocity(obj1.body_id, linearVelocity=vel1,
                            physicsClientId=self.physics_client)

        # Intuition (x-axis only, ignores gravity): obj1 moves toward obj2,
        # which is less than 2 m away.
        dx = pos2[0] - pos1[0]
        will_collide = vel1[0] * dx > 0 and abs(dx) < 2.0
        prediction = f"{'Yes' if will_collide else 'No'}, objects {'will' if will_collide else 'wont'} collide"

        # Simulate 2 seconds and watch for actual contact between the bodies.
        collision_occurred = self._advance(
            480, 1 / 240, watch=(obj1.body_id, obj2.body_id))

        simulated = f"Collision {'occurred' if collision_occurred else 'did not occur'}"

        return PhysicsQuery(
            question=f"Will {obj1_name} collide with {obj2_name}?",
            objects=[obj1, obj2],
            prediction=prediction,
            simulated_result=simulated,
            confidence=0.9 if (will_collide == collision_occurred) else 0.5
        )

    def what_happens_if(self, scenario: str) -> PhysicsQuery:
        """Natural language physics reasoning.

        Routes on substring keywords, checked in this order: "fall"/"drop",
        then "collide"/"hit", then "roll". Anything else returns a
        low-confidence PhysicsQuery with no simulation.
        """
        scenario_lower = scenario.lower()

        if "fall" in scenario_lower or "drop" in scenario_lower:
            return self.will_it_fall("object", height=1.0)
        elif "collide" in scenario_lower or "hit" in scenario_lower:
            return self.will_they_collide("ball", "wall", (0,0,0.5), (1,0,0.5))
        elif "roll" in scenario_lower:
            self._connect()  # Fresh world; registry now holds only the ball
            ball = self.spawn_object("ball", "sphere", (0, 0, 0.5))
            p.resetBaseVelocity(ball.body_id, linearVelocity=(1, 0, 0),
                                physicsClientId=self.physics_client)
            results = self.simulate(steps=240)
            return PhysicsQuery(
                question=scenario,
                objects=[ball],
                prediction="Ball will roll and slow due to friction",
                simulated_result=f"Ball final position: {results['ball']['position'][0]:.2f}m",
                confidence=0.85
            )
        else:
            return PhysicsQuery(
                question=scenario,
                objects=[],
                prediction="Unknown scenario - need more specific physics query",
                simulated_result=None,
                confidence=0.3
            )

    def close(self):
        """Disconnect from physics. Calling it again is a no-op."""
        if self.physics_client is not None:
            p.disconnect(physicsClientId=self.physics_client)
            self.physics_client = None
        # Body ids belonged to the closed world.
        self.objects = {}


# Smoke test: run a few canned queries and print Eden's prediction next to
# what the simulation actually did.
if __name__ == "__main__":
    def _report(result: PhysicsQuery) -> None:
        """Print a query's question, prediction, simulated outcome and confidence."""
        print(f"  Question: {result.question}")
        print(f"  Prediction: {result.prediction}")
        print(f"  Simulated: {result.simulated_result}")
        print(f"  Confidence: {result.confidence}")

    print("=" * 60)
    print("EDEN EMBODIMENT ENGINE TEST")
    print("=" * 60)

    embodiment = EdenEmbodiment()
    try:
        # Test 1: Will it fall?
        print("\n[TEST 1] Gravity")
        _report(embodiment.will_it_fall("apple", height=2.0))

        # Test 2: Collision
        print("\n[TEST 2] Collision")
        _report(embodiment.will_they_collide("ball", "wall", (0, 0, 0.5), (0.5, 0, 0.5)))

        # Test 3: Natural language
        print("\n[TEST 3] Natural Language")
        _report(embodiment.what_happens_if("What happens if I drop a ball?"))
    finally:
        # Release the PyBullet client even if a query raises.
        embodiment.close()
    print("\n✓ Embodiment engine ready")
