#!/usr/bin/env python3
"""
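COMPOSITIONAL GENERALIZATION - IMPROVED
Train a network that applies a pair of one-hot "skill" vectors to a 20-d object
vector (move, rotate, or scale it), using deeper layers and longer training.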
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

# Fixed seeds so data generation and weight init are reproducible run to run.
torch.manual_seed(42)
np.random.seed(42)

# Prefer the GPU, but fall back to CPU instead of crashing on the first
# `.to(device)` when no CUDA device is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}\n")

class ImprovedCompositionNet(nn.Module):
    """Encode two skill vectors, compose them, and apply the result to an object.

    Pipeline (see ``forward``):
        skill1, skill2      --skill_encoder (shared)-->  two 128-d codes
        [code1, code2]      --composer-->                256-d composition code
        [composition, obj]  --executor-->                predicted object

    Args:
        skill_dim: Length of each skill vector (last dimension of skill1/skill2).
        obj_dim: Length of the object vector; the output has the same length.
        dropout: Dropout probability used between hidden layers.

    The defaults (20, 20, 0.2) reproduce the original hard-coded architecture
    layer for layer, so state_dict keys and shapes are unchanged.
    """

    def __init__(self, skill_dim=20, obj_dim=20, dropout=0.2):
        super().__init__()
        # Skill encoder: skill_dim -> 512 -> 256 -> 128. The same module
        # encodes both skills, so the two share weights.
        self.skill_encoder = nn.Sequential(
            nn.Linear(skill_dim, 512),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(256, 128),
        )

        # Composer: concatenated pair of 128-d codes -> 256-d composition code.
        self.composer = nn.Sequential(
            nn.Linear(128 * 2, 1024),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(512, 256),
        )

        # Executor: composition code concatenated with the raw object ->
        # transformed object of the same length as the input object.
        self.executor = nn.Sequential(
            nn.Linear(256 + obj_dim, 512),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, obj_dim),
        )

    def forward(self, skill1, skill2, obj):
        """Predict the object after applying the composition of two skills.

        Args:
            skill1, skill2: Skill tensors, expected 2-D ``(batch, skill_dim)``
                since codes are concatenated along dim 1.
            obj: Object tensor, expected ``(batch, obj_dim)``.

        Returns:
            Tensor of shape ``(batch, obj_dim)``.
        """
        enc1 = self.skill_encoder(skill1)
        enc2 = self.skill_encoder(skill2)
        composed = self.composer(torch.cat([enc1, enc2], dim=1))
        result = self.executor(torch.cat([composed, obj], dim=1))
        return result

# Unit step along the first coordinate, used by the move_* compositions.
_STEP_X = np.eye(20)[0]

# comp_type -> (hot index in skill1, hot index in skill2, object transform).
# Insertion order is the sampling order passed to np.random.choice; keep it
# fixed so a given seed still draws the same composition sequence.
_COMPOSITIONS = {
    'move_forward': (0, 1, lambda o: o + _STEP_X),
    'move_back': (0, 2, lambda o: o - _STEP_X),
    'rotate_left': (3, 4, lambda o: np.roll(o, 1)),
    'rotate_right': (3, 5, lambda o: np.roll(o, -1)),
    'scale_up': (6, 7, lambda o: o * 1.5),
    'scale_down': (6, 8, lambda o: o * 0.5),
}
_COMP_NAMES = list(_COMPOSITIONS)


def create_composition_task(batch_size=128):
    """Sample a batch of (skill1, skill2, object, result) examples.

    Each example draws a 20-d standard-normal object and one of six
    compositions uniformly at random. The two skills are one-hot vectors whose
    pair identifies the composition, and ``result`` is the transformed object.
    Randomness comes from the global NumPy RNG. Per example, the object is
    drawn first and then the composition, the same order as before, so seeded
    runs reproduce the same data.

    Args:
        batch_size: Number of examples; 0 yields empty ``(0, 20)`` tensors.

    Returns:
        Tuple ``(skill1, skill2, objects, results)`` of float32 tensors, each
        of shape ``(batch_size, 20)``, placed on the module-level ``device``.

    Raises:
        ValueError: If ``batch_size`` is negative.
    """
    if batch_size < 0:
        raise ValueError(f"batch_size must be non-negative, got {batch_size}")

    dim = 20
    # Preallocate so the feature dimension survives even for an empty batch.
    skill1 = np.zeros((batch_size, dim))
    skill2 = np.zeros((batch_size, dim))
    objects = np.empty((batch_size, dim))
    results = np.empty((batch_size, dim))

    for i in range(batch_size):
        obj = np.random.randn(dim)
        comp_type = np.random.choice(_COMP_NAMES)
        s1_idx, s2_idx, transform = _COMPOSITIONS[comp_type]

        skill1[i, s1_idx] = 1.0
        skill2[i, s2_idx] = 1.0
        objects[i] = obj
        results[i] = transform(obj)

    # as_tensor converts and places on the target device in one call, unlike
    # the legacy torch.FloatTensor constructor followed by .to().
    return tuple(
        torch.as_tensor(arr, dtype=torch.float32, device=device)
        for arr in (skill1, skill2, objects, results)
    )

# Banner, then the model (moved to the selected device) and its Adam optimizer.
print("\n".join(["=" * 70, "COMPOSITIONAL GENERALIZATION - IMPROVED", "=" * 70]))

model = ImprovedCompositionNet().to(device)
opt = torch.optim.Adam(params=model.parameters(), lr=5e-4)

print("\nTraining (1000 epochs)...")

# Each iteration draws a brand-new random batch, so an "epoch" here is one
# optimizer step on fresh data.
for epoch in range(1000):
    skill1, skill2, obj, target = create_composition_task(batch_size=256)

    pred = model(skill1, skill2, obj)
    loss = F.mse_loss(pred, target)

    opt.zero_grad()
    loss.backward()
    nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    opt.step()

    # Log only every 100th step. The metrics use the training-mode forward
    # pass above (dropout active) and the weights from before opt.step().
    if epoch % 100:
        continue
    with torch.no_grad():
        error = (pred - target).abs().mean(dim=1)
        acc = error.lt(0.2).float().mean().item()
    print(f"  Epoch {epoch}: Loss={loss.item():.4f}, Acc={acc*100:.1f}%")

print("\n✅ Training complete")

# Extensive testing
print("\n" + "="*70)
print("TESTING (50 batches)")
print("="*70)

# Switch to inference mode. Without this the Dropout(0.2) layers stay active,
# and test accuracy is measured on randomly perturbed networks.
model.eval()

# NOTE(review): test batches come from the same six compositions used in
# training. This measures in-distribution accuracy, not generalization to
# unseen skill combinations.
test_accs = []
for _ in range(50):
    skill1, skill2, obj, target = create_composition_task(batch_size=200)

    with torch.no_grad():
        pred = model(skill1, skill2, obj)
        # A sample counts as correct when its mean absolute error is < 0.2.
        error = torch.abs(pred - target).mean(dim=1)
        acc = (error < 0.2).float().mean().item()
    test_accs.append(acc)

avg_acc = np.mean(test_accs)
std_acc = np.std(test_accs)
print(f"\nAccuracy: {avg_acc*100:.2f}% (±{std_acc*100:.2f}%)")

# Print the first tier whose threshold the mean test accuracy reaches; the
# for/else falls through to the last message when none match.
for threshold, verdict in ((0.95, "🎉 EXCEPTIONAL!"),
                           (0.90, "✅ EXCELLENT!"),
                           (0.85, "✅ STRONG!")):
    if avg_acc >= threshold:
        print(verdict)
        break
else:
    print("⚠️ Needs more work")

# Persist only the learned parameters (state_dict), not the module object.
checkpoint_path = 'composition_fixed.pth'
torch.save(model.state_dict(), checkpoint_path)
print("💾 Saved!")
