#!/usr/bin/env python3
"""
Common Sense Reasoning
Physical & social intuition
"""

import torch
import torch.nn.functional as F
from sentence_transformers import SentenceTransformer
import json

class CommonSenseReasoning:
    """Retrieve hand-written common-sense facts by embedding similarity.

    Each fact is a plain English sentence. A query is compared with every fact
    using cosine similarity between SentenceTransformer embeddings, and the
    closest facts are returned. The score reflects similarity of wording and
    topic, not truth: a false statement that shares vocabulary with a fact
    (e.g. "Fire is cold" vs "Fire is hot and can burn") can still score high.
    """

    def __init__(self):
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print("Loading models...")
        self.encoder = SentenceTransformer('all-MiniLM-L6-v2', device=self.device)

        # Knowledge base
        self.facts = self._load_common_sense()
        # Lazily filled (facts_tuple, embeddings) pair; see _fact_embeddings.
        self._fact_cache = None
        print("✅ Common sense ready!")

    def _load_common_sense(self):
        """Return a fresh list of the built-in common-sense fact sentences."""
        return [
            # Physical
            "Objects fall down due to gravity",
            "Water flows downhill",
            "Fire is hot and can burn",
            "Ice is cold and melts when heated",
            "Heavy objects sink in water unless they float",
            "Glass breaks when dropped",
            "Metal conducts electricity",
            "Plants need sunlight to grow",

            # Social
            "People need food and water to survive",
            "Children are younger than adults",
            "Being polite makes conversations pleasant",
            "Loud noises can be annoying",
            "People sleep at night",
            "Smiling usually indicates happiness",
            "Helping others is generally good",
            "Money is used to buy things",

            # Causal
            "Wet things eventually dry",
            "Practice improves skills",
            "Exercise makes you tired",
            "Eating too much causes fullness",
            "Rain makes things wet",
            "Cutting things makes them smaller",
        ]

    def _fact_embeddings(self, facts):
        """Return a CPU tensor with one embedding row per entry of *facts*.

        Re-encoding every fact on every query dominated the cost of
        ``reason``. The facts are batch-encoded once, and the result is reused
        until the fact list changes (the cache is keyed on the facts' contents,
        so appending to or replacing ``self.facts`` invalidates it).
        """
        key = tuple(facts)
        cache = getattr(self, "_fact_cache", None)
        if cache is None or cache[0] != key:
            embeddings = self.encoder.encode(list(key), convert_to_tensor=True).cpu()
            cache = (key, embeddings)
            self._fact_cache = cache
        return cache[1]

    def reason(self, query, top_k=3):
        """Return the facts most similar to *query*.

        Args:
            query: Free-text question or statement.
            top_k: Maximum number of facts to return (default 3); expected to
                be non-negative.

        Returns:
            A list of ``(fact, cosine_similarity)`` tuples, highest score
            first, at most ``top_k`` long. Empty when there are no facts.
        """
        facts = list(self.facts)
        if not facts:
            return []

        query_emb = self.encoder.encode(query, convert_to_tensor=True).cpu()
        fact_embs = self._fact_embeddings(facts)
        # (1, d) query broadcasts against (n, d) facts -> n similarities.
        sims = F.cosine_similarity(query_emb.unsqueeze(0), fact_embs, dim=1).tolist()

        # sorted() is stable, so tied scores keep knowledge-base order.
        ranked = sorted(zip(facts, sims), key=lambda pair: pair[1], reverse=True)
        return ranked[:top_k]

    def validate_statement(self, statement, threshold=0.5):
        """Check whether *statement* is close to some known fact.

        Args:
            statement: Free-text statement to check.
            threshold: Minimum best cosine similarity, exclusive (default 0.5).

        Returns:
            ``(is_plausible, best_score)``. ``best_score`` is the similarity of
            the closest fact, or 0 when there are no facts. This is
            relatedness, not a truth check (see the class docstring).
        """
        relevant = self.reason(statement, top_k=1)
        best_score = relevant[0][1] if relevant else 0
        return best_score > threshold, best_score

def test_common_sense(cs=None):
    """Smoke-test fact retrieval against a few hand-written questions.

    A question passes when its top-ranked fact contains the expected keyword.

    Args:
        cs: Optional object exposing ``reason(query)`` that returns
            ``(fact, score)`` pairs best-first. When omitted, a new
            ``CommonSenseReasoning`` (which loads the encoder) is created.
            Passing one lets callers reuse an already-loaded model.

    Returns:
        True if at least 60% of the questions pass, otherwise False.
    """
    print("\n" + "="*70)
    print("TESTING COMMON SENSE REASONING")
    print("="*70)

    if cs is None:
        cs = CommonSenseReasoning()

    # (question, keyword expected in the top fact, flag).
    # NOTE(review): the third field is never read; its meaning is unclear - confirm.
    tests = [
        ("What happens if you drop a glass?", "break", True),
        ("Why do people eat?", "survive", True),
        ("What happens to ice when heated?", "melt", True),
        ("Do rocks float on water?", "sink", False),
        ("Is fire cold?", "hot", False),
    ]

    passed = 0
    for query, keyword, _ in tests:
        results = cs.reason(query)
        # Guard: an empty knowledge base yields no results; count it as wrong
        # instead of raising IndexError.
        top_fact = results[0][0] if results else None
        print(f"\nQ: {query}")
        print(f"A: {top_fact if top_fact is not None else '(no facts)'}")

        if top_fact is not None and keyword in top_fact.lower():
            print("✅ Correct!")
            passed += 1
        else:
            print("❌ Wrong")

    print(f"\n{passed}/{len(tests)} ({100*passed/len(tests):.1f}%)")

    if passed >= len(tests) * 0.6:
        print("\n✅ Common sense working!")
        return True
    return False

def main():
    """Run the common-sense smoke test and return a process exit status.

    Returns:
        0 when the smoke test passes, 1 otherwise, so that the script's exit
        code reflects the result (previously it always exited with 0).
    """
    if test_common_sense():
        print("\n✅ CAPABILITY #12 COMPLETE")
        return 0
    return 1

if __name__ == "__main__":
    # Propagate main()'s status as the process exit code.
    raise SystemExit(main())
