#!/usr/bin/env python3
"""
ABSTRACTION & ANALOGICAL REASONING
Recognize patterns across domains and apply them to new contexts
Example: If "bird:fly :: fish:swim", then "car:drive :: boat:?"
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

# Reproducibility: fix both torch and numpy RNG streams.
torch.manual_seed(42)
np.random.seed(42)

# Fall back to CPU when no CUDA device is present (the original hard-coded
# 'cuda', which raises on CPU-only machines).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}\n")

class AbstractionNet(nn.Module):
    """Score candidate completions of an analogy a:b :: c:?.

    Every item is embedded with a shared encoder, pair relations
    (a->b and c->candidate) are encoded by a shared relation network,
    and a scorer compares the source relation against each candidate
    relation.
    """

    def __init__(self):
        super().__init__()
        # Shared per-item embedding: 20-dim input -> 64-dim code.
        self.encoder = nn.Sequential(
            nn.Linear(20, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
        )

        # Maps a concatenated pair of item codes to a 64-dim relation code.
        self.relation_net = nn.Sequential(
            nn.Linear(64 * 2, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
        )

        # Scores (source relation, candidate relation, candidate code) triples.
        self.decoder = nn.Sequential(
            nn.Linear(64 * 3, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
        )

    def forward(self, a, b, c, candidates):
        """Return a (batch, num_candidates) score matrix for a:b :: c:?."""
        embed = self.encoder
        code_a, code_b, code_c = embed(a), embed(b), embed(c)

        # Relation carried by the source pair a -> b.
        source_rel = self.relation_net(torch.cat([code_a, code_b], dim=1))

        def score_one(candidate):
            # Relation c -> candidate, judged against the source relation.
            code_cand = embed(candidate)
            cand_rel = self.relation_net(torch.cat([code_c, code_cand], dim=1))
            triple = torch.cat([source_rel, cand_rel, code_cand], dim=1)
            return self.decoder(triple)

        return torch.cat([score_one(cand) for cand in candidates], dim=1)

def create_analogy_problem(batch_size=32, device=None):
    """Generate a batch of abstract analogy problems a:b :: c:?.

    Each sample draws one random hidden transformation (add offset,
    scale, reverse, or circular shift), applies it to `a` to produce
    `b`, and applies the same transformation to `c` to produce the
    correct answer. Three noisy distractors around `c` complete a
    4-way multiple choice.

    Args:
        batch_size: number of analogy problems to generate.
        device: torch device for the returned tensors. Defaults to the
            module-level `device` global (CPU when that is undefined),
            preserving the original call signature.

    Returns:
        Tuple (X_a, X_b, X_c, candidates, labels): X_* are (B, 20)
        float tensors, candidates is (B, 4, 20), labels is a (B,) long
        tensor holding the index of the correct candidate.
    """
    if device is None:
        # Backward compatible: fall back to the script-level device.
        device = globals().get('device', torch.device('cpu'))

    X_a, X_b, X_c = [], [], []
    candidates_list = []
    labels = []

    for _ in range(batch_size):
        # Random transformation type
        transform_type = np.random.choice(['add', 'scale', 'reverse', 'shift'])

        # Generate base vectors
        a = np.random.randn(20)
        c = np.random.randn(20)

        # Apply the same hidden transformation to both a and c.
        if transform_type == 'add':
            offset = np.random.randn(20) * 2
            b = a + offset
            correct = c + offset
        elif transform_type == 'scale':
            scale = np.random.uniform(0.5, 2.0)
            b = a * scale
            correct = c * scale
        elif transform_type == 'reverse':
            b = a[::-1].copy()
            correct = c[::-1].copy()
        else:  # shift
            shift = np.random.randint(1, 10)
            b = np.roll(a, shift)
            correct = np.roll(c, shift)

        # Generate candidates (1 correct + 3 distractors around c).
        candidates = [correct]
        for _ in range(3):
            candidates.append(c + np.random.randn(20) * 2)

        # Swap the correct answer into a random slot; its index is the label.
        correct_idx = np.random.randint(0, 4)
        candidates[0], candidates[correct_idx] = candidates[correct_idx], candidates[0]

        X_a.append(a)
        X_b.append(b)
        X_c.append(c)
        candidates_list.append(candidates)
        labels.append(correct_idx)

    # np.stack first: constructing tensors from lists of ndarrays is slow
    # and triggers a UserWarning in recent PyTorch versions.
    X_a = torch.as_tensor(np.stack(X_a), dtype=torch.float32, device=device)
    X_b = torch.as_tensor(np.stack(X_b), dtype=torch.float32, device=device)
    X_c = torch.as_tensor(np.stack(X_c), dtype=torch.float32, device=device)
    candidates_tensor = torch.as_tensor(
        np.stack([np.stack(cands) for cands in candidates_list]),
        dtype=torch.float32, device=device)
    labels = torch.as_tensor(labels, dtype=torch.long, device=device)

    return X_a, X_b, X_c, candidates_tensor, labels

# Banner for the script's console output.
print("="*70)
print("ABSTRACTION & ANALOGICAL REASONING")
print("="*70)

# Model on the selected device; Adam with a standard 1e-3 learning rate.
model = AbstractionNet().to(device)
opt = torch.optim.Adam(model.parameters(), lr=0.001)

print("\nTraining (300 epochs)...")
for epoch in range(300):
    # Fresh random batch every epoch (no fixed dataset).
    X_a, X_b, X_c, candidates, labels = create_analogy_problem(batch_size=64)

    # Score the whole batch in one forward pass. The network is purely
    # batch-wise (linear layers + concatenation along dim=1), so slicing one
    # candidate column at a time yields the same scores as the original
    # per-sample Python loop, just vectorized and much faster.
    all_scores = model(X_a, X_b, X_c,
                       [candidates[:, j] for j in range(candidates.shape[1])])
    loss = F.cross_entropy(all_scores, labels)

    opt.zero_grad()
    loss.backward()
    opt.step()

    if epoch % 50 == 0:
        # Training accuracy on the current batch (chance level is 25%).
        acc = (all_scores.argmax(1) == labels).float().mean()
        print(f"  Epoch {epoch}: Loss={loss.item():.3f}, Acc={acc.item()*100:.1f}%")

print("\n✅ Training complete")

# Evaluate on 20 fresh random batches and average the accuracy.
print("\n" + "="*70)
print("TESTING")
print("="*70)

test_accs = []
for _ in range(20):
    X_a, X_b, X_c, candidates, labels = create_analogy_problem(batch_size=50)

    with torch.no_grad():
        # Single batched forward pass per evaluation batch — equivalent to
        # the original per-sample loop, just vectorized.
        all_scores = model(X_a, X_b, X_c,
                           [candidates[:, j] for j in range(candidates.shape[1])])
        acc = (all_scores.argmax(1) == labels).float().mean().item()
        test_accs.append(acc)

avg_acc = np.mean(test_accs)
print(f"\nAverage Test Accuracy: {avg_acc*100:.1f}%")

# Qualitative verdict; 4-way multiple choice, so chance level is 25%.
if avg_acc >= 0.90:
    print("🎉 EXCELLENT - Abstraction works!")
elif avg_acc >= 0.75:
    print("✅ GOOD - Strong analogical reasoning!")
elif avg_acc >= 0.60:
    print("⚠️ MODERATE - Needs improvement")
else:
    print("❌ Poor - Random is 25%")

# Persist weights only when the model generalizes well.
if avg_acc >= 0.75:
    torch.save(model.state_dict(), 'abstraction_model.pth')
    print("💾 Model saved!")