#!/usr/bin/env python3
"""
Analogical Reasoning
A:B :: C:D reasoning across domains
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
import random

# =============================================================================
# ANALOGICAL REASONING SYSTEM
# =============================================================================

class AnalogyEncoder(nn.Module):
    """
    Relation encoder for four-term analogies.

    Maps a pair of embeddings (A, B) to a unit-length relation vector and
    scores an analogy A:B :: C:D as the cosine similarity between the
    relation vectors of (A, B) and (C, D).
    """

    def __init__(self, embed_dim=384, hidden_dim=256):
        super().__init__()

        # MLP from a concatenated (A, B) pair down to a relation vector.
        layers = [
            nn.Linear(embed_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, embed_dim),
        ]
        self.relation_net = nn.Sequential(*layers)

    def encode_relation(self, emb_a, emb_b):
        """Return the L2-normalized relation vector for the pair (A, B)."""
        pair = torch.cat([emb_a, emb_b], dim=-1)
        return F.normalize(self.relation_net(pair), p=2, dim=-1)

    def forward(self, emb_a, emb_b, emb_c, emb_d):
        """Score how well A:B :: C:D holds (cosine of the two relation vectors)."""
        rel_left = self.encode_relation(emb_a, emb_b)
        rel_right = self.encode_relation(emb_c, emb_d)
        return F.cosine_similarity(rel_left, rel_right, dim=-1)

class AnalogicalReasoning:
    """Complete analogical reasoning system.

    Combines a frozen sentence encoder (text -> embedding) with a trainable
    AnalogyEncoder that scores how well two word pairs share a relation.
    """

    def __init__(self):
        # Keep embeddings and the analogy model on the same device.
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        print("Loading models...")
        self.text_encoder = SentenceTransformer('all-MiniLM-L6-v2', device=self.device)

        embed_dim = 384  # output dimension of all-MiniLM-L6-v2
        self.analogy_model = AnalogyEncoder(embed_dim=embed_dim).to(self.device)

        print("✅ Analogical reasoning ready!")

    def encode(self, text):
        """Encode text to a 1-D embedding tensor."""
        return self.text_encoder.encode(text, convert_to_tensor=True)

    def solve_analogy(self, a, b, c, candidates):
        """
        Solve A:B :: C:? given candidates.

        Args:
            a, b, c: strings
            candidates: list of possible D values
        Returns:
            (best_d, best_score); (None, -inf) if candidates is empty.
        """
        emb_a = self.encode(a).to(self.device)
        emb_b = self.encode(b).to(self.device)
        emb_c = self.encode(c).to(self.device)

        # FIX: force eval mode so Dropout is inactive — otherwise answers are
        # nondeterministic when the model was left in train mode. Restore the
        # previous mode afterwards so training callers are unaffected.
        was_training = self.analogy_model.training
        self.analogy_model.eval()

        # -inf (not -1) so even an all-negative-cosine candidate set yields an answer.
        best_score = float('-inf')
        best_candidate = None

        try:
            with torch.no_grad():
                for candidate in candidates:
                    emb_d = self.encode(candidate).to(self.device)

                    # Add batch dimension for the model's batched forward.
                    score = self.analogy_model(
                        emb_a.unsqueeze(0),
                        emb_b.unsqueeze(0),
                        emb_c.unsqueeze(0),
                        emb_d.unsqueeze(0)
                    ).item()

                    if score > best_score:
                        best_score = score
                        best_candidate = candidate
        finally:
            if was_training:
                self.analogy_model.train()

        return best_candidate, best_score

    def train_on_analogies(self, analogies, epochs=50):
        """
        Train the analogy model with a contrastive margin loss.

        Args:
            analogies: List of (A, B, C, D) string tuples.
            epochs: number of passes over the data.

        Side effects: saves the trained weights to 'analogy_model.pth'.
        """
        optimizer = torch.optim.Adam(self.analogy_model.parameters(), lr=0.001)

        print(f"\nTraining on {len(analogies)} analogies for {epochs} epochs...")

        # PERF: the sentence encoder is frozen, so encode every unique text
        # exactly once instead of re-encoding inside every epoch.
        unique_texts = {text for tup in analogies for text in tup}
        with torch.no_grad():
            cache = {
                text: self.encode(text).to(self.device).unsqueeze(0)
                for text in unique_texts
            }

        d_pool = [x[3] for x in analogies]

        # FIX: shuffle a copy — don't mutate the caller's list in place.
        data = list(analogies)
        self.analogy_model.train()

        for epoch in range(epochs):
            total_loss = 0
            random.shuffle(data)

            for a, b, c, d in tqdm(data, desc=f"Epoch {epoch+1}"):
                emb_a, emb_b, emb_c, emb_d = cache[a], cache[b], cache[c], cache[d]

                optimizer.zero_grad()

                # Positive score (should be high)
                pos_score = self.analogy_model(emb_a, emb_b, emb_c, emb_d)

                # Negative example: a random D that differs from the true D.
                # FIX: excluding only the exact tuple could still pick a
                # duplicate of d, making the margin loss degenerate.
                wrong_pool = [x for x in d_pool if x != d] or d_pool
                emb_wrong = cache[random.choice(wrong_pool)]
                neg_score = self.analogy_model(emb_a, emb_b, emb_c, emb_wrong)

                # Contrastive hinge loss: pos at least 0.5 above neg.
                loss = F.relu(0.5 - pos_score + neg_score)

                loss.backward()
                optimizer.step()

                total_loss += loss.item()

            if (epoch + 1) % 10 == 0:
                print(f"Epoch {epoch+1}: Loss: {total_loss/len(data):.4f}")

        print("✅ Training complete")

        # Persist trained weights for later reload.
        torch.save(self.analogy_model.state_dict(), 'analogy_model.pth')

# =============================================================================
# TRAINING DATA
# =============================================================================

def generate_analogy_data():
    """Return the fixed list of (A, B, C, D) training analogies, grouped by kind."""
    semantic = [
        ("king", "queen", "man", "woman"),
        ("hot", "cold", "big", "small"),
        ("dog", "puppy", "cat", "kitten"),
        ("doctor", "hospital", "teacher", "school"),
        ("car", "road", "boat", "water"),
        ("day", "night", "summer", "winter"),
        ("happy", "sad", "good", "bad"),
        ("eat", "food", "drink", "water"),
        ("run", "fast", "walk", "slow"),
        ("book", "read", "music", "listen"),
    ]

    functional = [
        ("hammer", "nail", "screwdriver", "screw"),
        ("pen", "write", "brush", "paint"),
        ("key", "lock", "password", "account"),
        ("fuel", "car", "electricity", "computer"),
        ("teacher", "student", "coach", "athlete"),
    ]

    category = [
        ("apple", "fruit", "carrot", "vegetable"),
        ("lion", "mammal", "eagle", "bird"),
        ("red", "color", "circle", "shape"),
        ("guitar", "instrument", "soccer", "sport"),
        ("novel", "literature", "painting", "art"),
    ]

    scale = [
        ("inch", "foot", "minute", "hour"),
        ("small", "medium", "cold", "warm"),
        ("few", "many", "little", "much"),
    ]

    opposites = [
        ("up", "down", "left", "right"),
        ("start", "end", "beginning", "finish"),
        ("question", "answer", "problem", "solution"),
    ]

    return semantic + functional + category + scale + opposites

# =============================================================================
# TESTING
# =============================================================================

def test_analogical_reasoning():
    """End-to-end demo: train the model, reload it, and score fixed test cases."""

    def banner(title):
        # Section header matching the file-wide separator style.
        print("\n" + "="*70)
        print(title)
        print("="*70)

    banner("TESTING ANALOGICAL REASONING")

    # Build the system and train it on the canned analogy set.
    ar = AnalogicalReasoning()
    ar.train_on_analogies(generate_analogy_data(), epochs=50)

    # Reload the saved weights and switch to inference mode.
    ar.analogy_model.load_state_dict(torch.load('analogy_model.pth'))
    ar.analogy_model.eval()

    banner("TEST: ANALOGY SOLVING")

    # (A, B, C, candidate Ds, expected D)
    test_cases = [
        ("bird", "fly", "fish", ["swim", "run", "jump", "crawl"], "swim"),
        ("sun", "day", "moon", ["night", "morning", "noon", "evening"], "night"),
        ("tree", "forest", "star", ["sky", "ground", "ocean", "mountain"], "sky"),
        ("bread", "baker", "house", ["builder", "teacher", "doctor", "farmer"], "builder"),
        ("hand", "glove", "foot", ["shoe", "hat", "shirt", "pants"], "shoe"),
        ("brain", "think", "heart", ["beat", "eat", "sleep", "walk"], "beat"),
        ("eye", "see", "ear", ["hear", "touch", "smell", "taste"], "hear"),
        ("fire", "hot", "ice", ["cold", "wet", "hard", "soft"], "cold"),
    ]

    passed = 0
    for a, b, c, candidates, expected in test_cases:
        answer, confidence = ar.solve_analogy(a, b, c, candidates)

        print(f"\n{a} : {b} :: {c} : ?")
        print(f"Candidates: {candidates}")
        print(f"Answer: {answer} (confidence: {confidence:.3f})")
        print(f"Expected: {expected}")

        if answer == expected:
            print("✅ Correct!")
            passed += 1
        else:
            print("❌ Wrong")

    banner("RESULTS")

    accuracy = 100 * passed / len(test_cases)
    print(f"\nAccuracy: {passed}/{len(test_cases)} ({accuracy:.1f}%)")

    if accuracy >= 75:
        print("\n✅ EXCELLENT - Analogical reasoning working!")
        return True
    if accuracy >= 60:
        print("\n✅ GOOD - Strong analogy capabilities!")
        return True
    print("\n⚠️ Needs improvement")
    return False

def main():
    """Run the analogy demo and, on success, print the capability summary."""
    succeeded = test_analogical_reasoning()

    print("\n" + "="*70)
    if not succeeded:
        return
    for line in (
        "✅ ANALOGICAL REASONING: WORKING",
        "\nCapabilities:",
        "  1. Semantic analogies",
        "  2. Functional analogies",
        "  3. Category analogies",
        "  4. Cross-domain transfer",
        "\n✅ CAPABILITY #10 COMPLETE",
        "\n📊 10/18 - Keep going!",
    ):
        print(line)

# Standard script entry point: run the demo only when executed directly.
if __name__ == "__main__":
    main()
