#!/usr/bin/env python3
"""
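Compositional Reasoning V2 - Improved
Composes two concept embeddings with several operators (additive,
multiplicative, weighted, averaged) and ranks candidate phrases by cosine
similarity; includes tests for composition, generalization, and decomposition.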
"""

import torch
import torch.nn.functional as F
import numpy as np
from sentence_transformers import SentenceTransformer
from tqdm import tqdm

# =============================================================================
# IMPROVED COMPOSITIONAL REASONING
# =============================================================================

class CompositionalReasoningV2:
    """
    Improved compositional reasoning using multiple composition operators
    """
    
    def __init__(self):
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        
        print("Loading sentence encoder...")
        self.encoder = SentenceTransformer('all-MiniLM-L6-v2', device=self.device)
        
        print("✅ Compositional reasoning V2 ready!")
    
    def encode(self, text):
        """Encode text to embedding"""
        return self.encoder.encode(text, convert_to_tensor=True).cpu()
    
    def compose_additive(self, emb_a, emb_b):
        """Simple addition composition"""
        composed = emb_a + emb_b
        return F.normalize(composed, p=2, dim=-1)
    
    def compose_multiplicative(self, emb_a, emb_b):
        """Element-wise multiplication"""
        composed = emb_a * emb_b
        return F.normalize(composed, p=2, dim=-1)
    
    def compose_weighted(self, emb_a, emb_b, alpha=0.6):
        """Weighted combination (favor second element)"""
        composed = alpha * emb_b + (1 - alpha) * emb_a
        return F.normalize(composed, p=2, dim=-1)
    
    def compose_concatenate(self, emb_a, emb_b):
        """Concatenate and average (simple baseline)"""
        # For same-size embeddings, just average
        composed = (emb_a + emb_b) / 2
        return F.normalize(composed, p=2, dim=-1)
    
    def compose_multi_method(self, emb_a, emb_b):
        """Try multiple composition methods and return all"""
        methods = {
            'additive': self.compose_additive(emb_a, emb_b),
            'multiplicative': self.compose_multiplicative(emb_a, emb_b),
            'weighted': self.compose_weighted(emb_a, emb_b),
            'concatenate': self.compose_concatenate(emb_a, emb_b),
        }
        return methods
    
    def compose_concepts(self, concept_a, concept_b, candidates):
        """
        Compose two concepts and find best match
        
        Uses multiple composition methods and voting
        
        Args:
            concept_a, concept_b: base concepts
            candidates: possible compositions
        Returns:
            (best_match, confidence)
        """
        emb_a = self.encode(concept_a)
        emb_b = self.encode(concept_b)
        
        # Get all composition methods
        composed_versions = self.compose_multi_method(emb_a, emb_b)
        
        # Score each candidate with each method
        candidate_scores = {c: [] for c in candidates}
        
        for method_name, composed_emb in composed_versions.items():
            for candidate in candidates:
                cand_emb = self.encode(candidate)
                score = F.cosine_similarity(
                    composed_emb.unsqueeze(0),
                    cand_emb.unsqueeze(0)
                ).item()
                candidate_scores[candidate].append(score)
        
        # Average scores across methods
        avg_scores = {c: np.mean(scores) for c, scores in candidate_scores.items()}
        
        # Get best
        best_candidate = max(avg_scores.keys(), key=lambda k: avg_scores[k])
        best_score = avg_scores[best_candidate]
        
        return best_candidate, best_score
    
    def decompose_concept(self, composed, primitives):
        """
        Try to decompose a concept into primitives
        
        Args:
            composed: complex concept
            primitives: list of simple concepts
        Returns:
            List of (primitive, score) sorted by relevance
        """
        composed_emb = self.encode(composed)
        
        scores = []
        for primitive in primitives:
            prim_emb = self.encode(primitive)
            score = F.cosine_similarity(
                composed_emb.unsqueeze(0),
                prim_emb.unsqueeze(0)
            ).item()
            scores.append((primitive, score))
        
        # Sort by score
        scores.sort(key=lambda x: x[1], reverse=True)
        return scores
    
    def systematic_generalization(self, known_compositions, test_a, test_b, candidates):
        """
        Use known compositions to generalize to new ones
        
        Example: If we know "red" + "apple" = "red apple"
        And "blue" is similar to "red"
        Then "blue" + "apple" should be similar to "red apple"
        """
        # Get embeddings
        test_emb_a = self.encode(test_a)
        test_emb_b = self.encode(test_b)
        
        # Find most similar known composition
        best_analogy_score = -1
        best_template = None
        
        for (known_a, known_b, known_c) in known_compositions:
            known_emb_a = self.encode(known_a)
            known_emb_b = self.encode(known_b)
            
            # How similar is (test_a, test_b) to (known_a, known_b)?
            sim_a = F.cosine_similarity(test_emb_a.unsqueeze(0), known_emb_a.unsqueeze(0)).item()
            sim_b = F.cosine_similarity(test_emb_b.unsqueeze(0), known_emb_b.unsqueeze(0)).item()
            
            analogy_score = (sim_a + sim_b) / 2
            
            if analogy_score > best_analogy_score:
                best_analogy_score = analogy_score
                best_template = (known_a, known_b, known_c)
        
        # Use template to compose
        if best_template:
            template_c_emb = self.encode(best_template[2])
            
            # Score candidates based on similarity to template result
            scores = {}
            for candidate in candidates:
                cand_emb = self.encode(candidate)
                
                # Direct composition
                composed = self.compose_weighted(test_emb_a, test_emb_b)
                direct_score = F.cosine_similarity(composed.unsqueeze(0), cand_emb.unsqueeze(0)).item()
                
                # Template-based
                template_score = F.cosine_similarity(template_c_emb.unsqueeze(0), cand_emb.unsqueeze(0)).item()
                
                # Combine
                scores[candidate] = 0.7 * direct_score + 0.3 * template_score
            
            best = max(scores.keys(), key=lambda k: scores[k])
            return best, scores[best]
        else:
            # Fall back to direct composition
            return self.compose_concepts(test_a, test_b, candidates)

# =============================================================================
# TESTING
# =============================================================================

def _print_section(title):
    """Print ``title`` framed by 70-character '=' rules, preceded by a blank line."""
    print("\n" + "=" * 70)
    print(title)
    print("=" * 70)


def _report_case(ok):
    """Print the pass/fail marker for one case; return 1 if ``ok`` else 0 (for tallying)."""
    if ok:
        print("✅ Correct!")
        return 1
    print("❌ Wrong")
    return 0


def _run_pair_cases(cases, solve, prompt_suffix=""):
    """
    Run ``(a, b, candidates, expected)`` cases through ``solve`` and count passes.

    ``solve(a, b, candidates)`` must return ``(result, confidence)``. Each case
    prints the prompt, the result with its confidence, and the expected answer.
    """
    passed = 0
    for a, b, candidates, expected in cases:
        result, confidence = solve(a, b, candidates)

        print(f"\n{a} + {b} = ?{prompt_suffix}")
        print(f"Result: {result} (confidence: {confidence:.3f})")
        print(f"Expected: {expected}")

        passed += _report_case(result == expected)
    return passed


def test_compositional_v2():
    """
    Run the V2 compositional-reasoning demo suite and print a summary.

    Sections: basic composition, complex (multi-word) composition, systematic
    generalization from known compositions, and concept decomposition. Every
    case prints both the produced and the expected answer.

    Returns:
        True if at least 65% of all cases passed, else False.
    """
    _print_section("TESTING COMPOSITIONAL REASONING V2")

    cr = CompositionalReasoningV2()

    # Test 1: Basic Composition
    _print_section("TEST 1: BASIC CONCEPT COMPOSITION")

    basic_tests = [
        ("red", "apple", ["red apple", "green banana", "blue sky", "yellow sun"], "red apple"),
        ("big", "house", ["big house", "small car", "tall tree", "wide road"], "big house"),
        ("fast", "car", ["fast car", "slow bike", "quick train", "speedy plane"], "fast car"),
        ("cold", "water", ["cold water", "hot coffee", "warm milk", "cool juice"], "cold water"),
        ("happy", "person", ["happy person", "sad child", "angry man", "excited woman"], "happy person"),
        ("loud", "music", ["loud music", "quiet room", "soft voice", "noisy crowd"], "loud music"),
        ("fresh", "bread", ["fresh bread", "stale cake", "old cheese", "new shirt"], "fresh bread"),
        ("bright", "light", ["bright light", "dark room", "dim lamp", "shiny metal"], "bright light"),
    ]
    passed = _run_pair_cases(basic_tests, cr.compose_concepts)

    # Test 2: Complex Composition
    _print_section("TEST 2: COMPLEX COMPOSITION")

    complex_tests = [
        ("machine", "learning", ["machine learning", "human thinking", "computer programming", "robot moving"], "machine learning"),
        ("artificial", "intelligence", ["artificial intelligence", "natural stupidity", "real emotion", "fake news"], "artificial intelligence"),
        ("neural", "network", ["neural network", "social media", "road system", "computer chip"], "neural network"),
        ("deep", "learning", ["deep learning", "shallow water", "surface reading", "quick thinking"], "deep learning"),
    ]
    complex_passed = _run_pair_cases(complex_tests, cr.compose_concepts)

    # Test 3: Systematic Generalization
    _print_section("TEST 3: SYSTEMATIC GENERALIZATION")

    # Known compositions used as analogy templates
    known = [
        ("red", "apple", "red apple"),
        ("blue", "sky", "blue sky"),
        ("big", "house", "big house"),
    ]

    gen_tests = [
        ("green", "apple", ["green apple", "red banana", "blue orange", "yellow lemon"], "green apple"),
        ("small", "house", ["small house", "big apartment", "tiny car", "large building"], "small house"),
        ("yellow", "sky", ["yellow sky", "green grass", "blue ocean", "red car"], "yellow sky"),
    ]
    gen_passed = _run_pair_cases(
        gen_tests,
        lambda a, b, candidates: cr.systematic_generalization(known, a, b, candidates),
        prompt_suffix=" (using systematic generalization)",
    )

    # Test 4: Decomposition
    _print_section("TEST 4: CONCEPT DECOMPOSITION")

    decomp_tests = [
        ("red apple", ["red", "apple", "green", "banana"], ["red", "apple"]),
        ("fast car", ["fast", "car", "slow", "bike"], ["fast", "car"]),
    ]

    decomp_passed = 0
    for composed, primitives, expected in decomp_tests:
        results = cr.decompose_concept(composed, primitives)
        top_2 = [name for name, _score in results[:2]]

        print(f"\nDecompose: {composed}")
        print(f"Top components: {top_2}")
        print(f"Expected: {expected}")

        # Order-insensitive: only which two primitives rank highest matters.
        decomp_passed += _report_case(set(top_2) == set(expected))

    # Results
    _print_section("FINAL RESULTS")

    total = len(basic_tests) + len(complex_tests) + len(gen_tests) + len(decomp_tests)
    total_passed = passed + complex_passed + gen_passed + decomp_passed

    print(f"\nBasic Composition: {passed}/{len(basic_tests)}")
    print(f"Complex Composition: {complex_passed}/{len(complex_tests)}")
    print(f"Systematic Generalization: {gen_passed}/{len(gen_tests)}")
    print(f"Decomposition: {decomp_passed}/{len(decomp_tests)}")
    print(f"\nTotal: {total_passed}/{total} ({100*total_passed/total:.1f}%)")

    if total_passed >= total * 0.8:
        print("\n✅ EXCELLENT - Compositional reasoning working!")
    elif total_passed >= total * 0.65:
        print("\n✅ GOOD - Strong composition!")
    else:
        print("\n⚠️ Needs improvement")

    # Both EXCELLENT and GOOD count as success.
    return total_passed >= total * 0.65

def main():
    """Run the V2 suite; print the capability summary only when it reports success."""
    ok = test_compositional_v2()

    print("\n" + "=" * 70)
    if not ok:
        # Failure details were already printed by the suite itself.
        return

    summary_lines = (
        "✅ COMPOSITIONAL REASONING V2: WORKING",
        "\nCapabilities:",
        "  1. Basic concept composition",
        "  2. Complex multi-word composition",
        "  3. Systematic generalization",
        "  4. Concept decomposition",
        "\n✅ CAPABILITY #11 COMPLETE (IMPROVED)",
        "\n📊 11/18 - Keep pushing!",
    )
    for line in summary_lines:
        print(line)


if __name__ == "__main__":
    main()
