#!/usr/bin/env python3
"""
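Social intelligence demo: match a described situation to the most similar social rule using sentence embeddings (all-MiniLM-L6-v2) and cosine similarity.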
"""
from sentence_transformers import SentenceTransformer
import torch
import torch.nn.functional as F

class SocialIntelligence:
    """Match a free-text social situation to the most similar rule of thumb.

    The situation and every entry of ``social_rules`` are embedded with a
    SentenceTransformer encoder ('all-MiniLM-L6-v2', on CUDA when available)
    and compared by cosine similarity.
    """

    def __init__(self):
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print("Loading models...")
        self.encoder = SentenceTransformer('all-MiniLM-L6-v2', device=self.device)
        # Candidate rules; predict_interaction returns the most similar one.
        self.social_rules = [
            "Cooperation leads to mutual benefit",
            "Competition can drive improvement",
            "Trust is built through consistent actions",
            "Communication reduces conflict",
            "Shared goals unite groups",
            "Respect differences for harmony",
        ]
        # (rules tuple, stacked rule embeddings), filled lazily by
        # _rule_embeddings so rules are encoded once, not on every prediction.
        self._rule_cache = None
        # Announce readiness only once every attribute is initialised.
        print("✅ Social intelligence ready!")

    def _rule_embeddings(self):
        """Return ``(rules, embeddings)`` for the current ``social_rules``.

        ``rules`` is a tuple snapshot of ``social_rules``. ``embeddings`` is a
        CPU tensor with one row per rule, in the same order. The result is
        cached and recomputed only when ``social_rules`` has changed since the
        last call.
        """
        rules = tuple(self.social_rules)
        # getattr keeps this working on instances built without __init__.
        cache = getattr(self, "_rule_cache", None)
        if cache is None or cache[0] != rules:
            # One batched encode call instead of one call per rule.
            embeddings = self.encoder.encode(list(rules), convert_to_tensor=True).cpu()
            cache = (rules, embeddings)
            self._rule_cache = cache
        return cache

    def predict_interaction(self, situation: str):
        """Return the rule most similar to *situation* and its similarity.

        Args:
            situation: Free-text description of a social situation.

        Returns:
            ``(rule, score)`` where ``rule`` is the best-matching entry of
            ``social_rules`` and ``score`` is its cosine similarity as a
            Python float. On ties, the earliest rule wins. If
            ``social_rules`` is empty, returns ``(None, -1)``.
        """
        if not self.social_rules:
            return None, -1
        sit_emb = self.encoder.encode(situation, convert_to_tensor=True).cpu()
        rules, rule_embs = self._rule_embeddings()
        # Broadcast the single situation vector against every rule row.
        scores = F.cosine_similarity(sit_emb.unsqueeze(0), rule_embs, dim=1)
        # argmax returns the first maximal index, matching a strict-'>' scan,
        # but unlike a -1 sentinel it always selects some rule.
        best = int(torch.argmax(scores))
        return rules[best], scores[best].item()

def test_social():
    """Smoke-test SocialIntelligence on a few hand-labelled situations.

    Each case pairs a situation with a keyword expected to appear
    (case-insensitively) in the predicted rule. Per-case results are printed.

    Returns:
        True when at least two cases match, False otherwise.
    """
    banner = "=" * 70
    print("\n" + banner)
    print("TESTING SOCIAL INTELLIGENCE")
    print(banner)
    si = SocialIntelligence()
    tests = [
        ("Two people working together on a project", "Cooperation"),
        ("Athletes competing in a race", "Competition"),
        ("Team members arguing about approach", "Communication"),
    ]
    required = 2  # minimum number of correct predictions to count as working
    passed = 0
    for situation, keyword in tests:
        result, conf = si.predict_interaction(situation)
        print(f"\nSituation: {situation}")
        print(f"Prediction: {result}")
        print(f"Confidence: {conf:.3f}")
        # predict_interaction can return (None, -1); count that as a miss
        # rather than crashing on None.lower().
        if result is not None and keyword.lower() in result.lower():
            print("✅ Correct!")
            passed += 1
        else:
            print(f"❌ Expected a rule mentioning '{keyword}'")
    print(f"\n{passed}/{len(tests)}")
    if passed >= required:
        print("✅ Social intelligence working!")
        return True
    return False

def main():
    """Run the social-intelligence self-test; announce completion on success."""
    succeeded = test_social()
    if not succeeded:
        return
    print("\n✅ CAPABILITY #18 COMPLETE")

# Script entry point: run the self-test only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()
