#!/usr/bin/env python3
"""
COMPOSITIONAL GENERALIZATION - CLASSIFICATION
Simpler: predict which transformation was applied
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

# Seed both RNGs: NumPy drives the synthetic data, torch drives weight
# initialisation and dropout masks.
torch.manual_seed(42)
np.random.seed(42)

# Prefer the GPU, but fall back to the CPU so the script still runs on
# machines without CUDA instead of failing on the first `.to(device)` call.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}\n")

class CompositionClassifier(nn.Module):
    """Feed-forward classifier over four concatenated feature vectors.

    The network receives two skill vectors plus an object vector before and
    after a transformation, concatenates them along the feature axis and
    emits one logit per composition type.

    Args:
        feature_dim: Width of each of the four input vectors. The default
            (20) matches the original fixed architecture.
        num_classes: Number of output logits (6 composition types by
            default).
        dropout: Drop probability used after the first two hidden layers.
    """

    def __init__(self, feature_dim=20, num_classes=6, dropout=0.3):
        super().__init__()
        # Input: skill1 + skill2 + object_before + object_after
        # Layer order is kept fixed so state_dict keys (net.0, net.3, net.6,
        # net.8) remain the same as with the previously hard-coded sizes.
        self.net = nn.Sequential(
            nn.Linear(feature_dim * 4, 1024),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, num_classes),  # one logit per composition type
        )

    def forward(self, skill1, skill2, obj_before, obj_after):
        """Return raw (unnormalised) class logits for a batch.

        All four inputs are concatenated along dim 1, so each must be a 2-D
        tensor with the same batch size; their widths must sum to the input
        width of the first linear layer.
        """
        x = torch.cat([skill1, skill2, obj_before, obj_after], dim=1)
        return self.net(x)

def create_task(batch_size=128):
    """Generate one random batch for the transformation-classification task.

    For every sample a composition id is drawn uniformly from 0-5, a
    20-element object vector is drawn from a standard normal, and the
    matching transformation is applied:

        0  move / forward    add 1.0 to element 0
        1  move / back       subtract 1.0 from element 0
        2  rotate / left     np.roll(x, 1)
        3  rotate / right    np.roll(x, -1)
        4  scale / up        multiply by 1.5
        5  scale / down      multiply by 0.5

    The composition is also described by two one-hot vectors. Each skill
    family owns a group of three slots (move 0-2, rotate 3-5, scale 6-8):
    ``skill1`` marks the first slot of the group and ``skill2`` marks the
    second or third slot depending on the direction.

    Args:
        batch_size: Number of samples to generate. Zero yields empty
            tensors that still have a feature axis of width 20.

    Returns:
        Tuple ``(skill1, skill2, obj_before, obj_after, labels)``. The first
        four are float32 tensors of shape ``(batch_size, 20)``; ``labels``
        is an int64 tensor of shape ``(batch_size,)``. All tensors are
        placed on the module-level ``device``.
    """
    dim = 20
    num_compositions = 6

    # Build the translation offset once instead of once per sample.
    unit_shift = np.zeros(dim)
    unit_shift[0] = 1.0

    # Indexed by composition id.
    transforms = (
        lambda x: x + unit_shift,   # move forward
        lambda x: x - unit_shift,   # move back
        lambda x: np.roll(x, 1),    # rotate left
        lambda x: np.roll(x, -1),   # rotate right
        lambda x: x * 1.5,          # scale up
        lambda x: x * 0.5,          # scale down
    )

    # Preallocating keeps the (batch_size, dim) shape even for an empty batch.
    skill1 = np.zeros((batch_size, dim))
    skill2 = np.zeros((batch_size, dim))
    before = np.empty((batch_size, dim))
    after = np.empty((batch_size, dim))
    labels = np.empty(batch_size, dtype=np.int64)

    for i in range(batch_size):
        # Draw order (id first, then the object) is kept per sample so the
        # seeded NumPy random stream is consumed as before.
        comp_id = np.random.randint(0, num_compositions)
        obj_before = np.random.randn(dim)

        before[i] = obj_before
        after[i] = transforms[comp_id](obj_before)

        # family: 0 = move, 1 = rotate, 2 = scale; direction: 0 or 1.
        family, direction = divmod(comp_id, 2)
        base = 3 * family
        skill1[i, base] = 1
        # Slots base+1 / base+2 are unique per family. Previously the scale
        # family used slots 5 and 6, colliding with 'rotate right' (5) and
        # with the scale family's own skill1 slot (6).
        skill2[i, base + 1 + direction] = 1

        labels[i] = comp_id

    def to_float(array):
        return torch.tensor(array, dtype=torch.float32, device=device)

    return (to_float(skill1),
            to_float(skill2),
            to_float(before),
            to_float(after),
            torch.tensor(labels, dtype=torch.long, device=device))

rule = "=" * 70
print(rule)
print("COMPOSITIONAL GENERALIZATION - CLASSIFICATION")
print(rule)

model = CompositionClassifier().to(device)
opt = torch.optim.Adam(model.parameters(), lr=0.001)

print("\nTraining (600 epochs)...")

# Each "epoch" here is a single optimisation step on a freshly generated
# batch of 256 samples; no fixed dataset is iterated.
for epoch in range(600):
    *inputs, targets = create_task(256)

    logits = model(*inputs)
    loss = F.cross_entropy(logits, targets)

    opt.zero_grad()
    loss.backward()
    opt.step()

    # Progress is reported every 100 steps, using the predictions made
    # before this step's weight update.
    if epoch % 100 == 0:
        hits = logits.argmax(1) == targets
        acc = hits.float().mean().item()
        print(f"  Epoch {epoch}: Loss={loss.item():.3f}, Acc={acc*100:.1f}%")

print("\n✅ Training complete")

# Test
print("\n" + "="*70)
print("TESTING")
print("="*70)

# The network contains Dropout layers. Without switching to eval mode they
# keep randomly zeroing activations, which makes the measured accuracy noisy
# and lower than what the trained weights actually achieve.
model.eval()

test_accs = []
# Gradients are not needed for evaluation, including batch generation.
with torch.no_grad():
    for _ in range(30):
        s1, s2, before, after, labels = create_task(200)
        pred = model(s1, s2, before, after)
        acc = (pred.argmax(1) == labels).float().mean().item()
        test_accs.append(acc)

# Mean accuracy over the 30 freshly generated batches of 200 samples.
avg = np.mean(test_accs)
print(f"\nAccuracy: {avg*100:.1f}%")

if avg >= 0.95:
    print("🎉 EXCEPTIONAL!")
elif avg >= 0.90:
    print("✅ EXCELLENT!")
elif avg >= 0.85:
    print("✅ STRONG!")
else:
    print("⚠️ Acceptable for complex task")

# Only the parameters are stored; loading requires constructing the same
# architecture first and then calling load_state_dict.
torch.save(model.state_dict(), 'composition_classifier.pth')
print("💾 Saved!")
