#!/usr/bin/env python3
"""
ABSTRACTION 100% - FIXED (no BatchNorm issues)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

torch.manual_seed(777)
np.random.seed(777)

# Use CUDA when available; fall back to CPU so the script runs anywhere.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}\n")

class UltimateAbstractionNet(nn.Module):
    def __init__(self):
        super().__init__()
        # Shared encoder: 20-dim vector -> 128-dim embedding (no BatchNorm,
        # so single-sample forward passes are safe)
        self.encoder = nn.Sequential(
            nn.Linear(20, 512),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, 128)
        )
        
        # Relation network: a pair of embeddings -> 256-dim relation vector
        self.relation_net = nn.Sequential(
            nn.Linear(128 * 2, 1024),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512, 256)
        )
        
        # Scorer: (A:B relation, C:candidate relation, candidate embedding) -> scalar
        self.scorer = nn.Sequential(
            nn.Linear(256 * 2 + 128, 1024),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, 1)
        )
    
    def forward(self, a, b, c, candidates):
        # Encode the three given vectors with the shared encoder.
        enc_a = self.encoder(a)
        enc_b = self.encoder(b)
        enc_c = self.encoder(c)
        
        # Relation vector capturing the A -> B transformation.
        rel_ab = self.relation_net(torch.cat([enc_a, enc_b], dim=1))
        
        # Score each candidate D by how well the C -> D relation matches A -> B.
        scores = []
        for candidate in candidates:
            enc_cand = self.encoder(candidate)
            rel_c_cand = self.relation_net(torch.cat([enc_c, enc_cand], dim=1))
            combined = torch.cat([rel_ab, rel_c_cand, enc_cand], dim=1)
            score = self.scorer(combined)
            scores.append(score)
        
        return torch.cat(scores, dim=1)  # (batch, n_candidates)
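
# Hedged smoke test (illustrative only, not part of training): for a toy batch
# of 2 problems with 4 candidates, the scorer should return shape (2, 4).
_probe = UltimateAbstractionNet()
with torch.no_grad():
    _out = _probe(torch.randn(2, 20), torch.randn(2, 20), torch.randn(2, 20),
                  [torch.randn(2, 20) for _ in range(4)])
assert _out.shape == (2, 4), f"unexpected score shape: {tuple(_out.shape)}"
del _probe, _out
torch.manual_seed(777)  # re-seed so the probe does not shift the real model's init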

def create_analogy_problem(batch_size=128):
    """Generate a batch of A:B :: C:? analogy problems, each with 4 candidates."""
    X_a, X_b, X_c = [], [], []
    candidates_list = []
    labels = []
    
    for _ in range(batch_size):
        # Sample a hidden transformation; the correct answer applies the same
        # transformation to c that maps a to b.
        transform_type = np.random.choice(['add', 'scale', 'reverse', 'shift'])
        
        a = np.random.randn(20)
        c = np.random.randn(20)
        
        if transform_type == 'add':
            offset = np.random.randn(20) * 2
            b = a + offset
            correct = c + offset
        elif transform_type == 'scale':
            scale = np.random.uniform(0.5, 2.0)
            b = a * scale
            correct = c * scale
        elif transform_type == 'reverse':
            b = a[::-1].copy()
            correct = c[::-1].copy()
        else:
            shift = np.random.randint(1, 10)
            b = np.roll(a, shift)
            correct = np.roll(c, shift)
        
        # One correct answer plus three noisy distractors around c.
        candidates = [correct]
        for _ in range(3):
            distractor = c + np.random.randn(20) * 3
            candidates.append(distractor)
        
        # Move the correct answer to a random slot.
        correct_idx = np.random.randint(0, 4)
        candidates[0], candidates[correct_idx] = candidates[correct_idx], candidates[0]
        
        X_a.append(a)
        X_b.append(b)
        X_c.append(c)
        candidates_list.append(candidates)
        labels.append(correct_idx)
    
    X_a = torch.tensor(np.array(X_a), dtype=torch.float32, device=device)
    X_b = torch.tensor(np.array(X_b), dtype=torch.float32, device=device)
    X_c = torch.tensor(np.array(X_c), dtype=torch.float32, device=device)
    candidates_tensor = torch.tensor(np.array(candidates_list), dtype=torch.float32, device=device)
    labels = torch.tensor(labels, dtype=torch.long, device=device)
    
    return X_a, X_b, X_c, candidates_tensor, labels
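
# Illustrative peek at the generator output (hypothetical tiny batch):
# X_a/X_b/X_c are (batch, 20), candidates is (batch, 4, 20), labels is (batch,).
_a, _b, _c, _cands, _y = create_analogy_problem(batch_size=4)
print(f"Generator shapes: a={tuple(_a.shape)}, candidates={tuple(_cands.shape)}, labels={tuple(_y.shape)}")
del _a, _b, _c, _cands, _y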

print("="*70)
print("ABSTRACTION 100% - ULTIMATE")
print("="*70)

model = UltimateAbstractionNet().to(device)
opt = torch.optim.Adam(model.parameters(), lr=0.0005)

print("\nTraining (max 1000 epochs)...")

for epoch in range(1000):
    X_a, X_b, X_c, candidates, labels = create_analogy_problem(batch_size=256)
    
    model.train()
    # The network is fully batched, so score the whole batch in one forward
    # pass; each list entry is one candidate column of shape (batch, 20).
    all_scores = model(X_a, X_b, X_c, [candidates[:, j] for j in range(4)])
    loss = F.cross_entropy(all_scores, labels)
    
    opt.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    opt.step()
    
    if epoch % 50 == 0:
        acc = (all_scores.argmax(1) == labels).float().mean().item()
        print(f"  Epoch {epoch}: Loss={loss.item():.4f}, Acc={acc*100:.1f}%")
        
        if acc >= 0.98 and epoch > 300:
            print(f"✅ Near-perfect at epoch {epoch}!")
            break

print("\n✅ Training complete")

# Test
print("\n" + "="*70)
print("FINAL TEST (100 batches, 200 samples each)")
print("="*70)

model.eval()
test_accs = []

for batch_num in range(100):
    X_a, X_b, X_c, candidates, labels = create_analogy_problem(batch_size=200)
    
    with torch.no_grad():
        # Same batched forward pass as in training.
        all_scores = model(X_a, X_b, X_c, [candidates[:, j] for j in range(4)])
        acc = (all_scores.argmax(1) == labels).float().mean().item()
        test_accs.append(acc)
    
    if (batch_num + 1) % 20 == 0:
        print(f"  {(batch_num+1)*200} samples: {np.mean(test_accs)*100:.2f}%")

avg = np.mean(test_accs)
std = np.std(test_accs)

print(f"\n{'='*70}")
print(f"Average: {avg*100:.3f}% (±{std*100:.3f}%)")
print(f"{'='*70}")

if avg >= 0.95:
    print("🎉 EXCEPTIONAL ABSTRACTION!")
elif avg >= 0.92:
    print("✅ EXCELLENT!")
elif avg >= 0.90:
    print("✅ GREAT!")
else:
    print("⚠️ Below 90% - needs more training")

torch.save(model.state_dict(), 'abstraction_100.pth')
print("💾 Saved!")
