#!/usr/bin/env python3
"""
COMPOSITIONAL GENERALIZATION
Combine known skills in novel ways to solve new problems
Example: Know "jump" and "over" → can do "jump over X" for any X
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

torch.manual_seed(42)
np.random.seed(42)

device = torch.device('cuda')
print(f"Device: {device}\n")

class CompositionNet(nn.Module):
    def __init__(self):
        super().__init__()
        # Primitive skill encoder
        self.skill_encoder = nn.Sequential(
            nn.Linear(20, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, 128)
        )
        
        # Composition network - combines skills
        self.composer = nn.Sequential(
            nn.Linear(128 * 2, 512),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, 128)  # Composed skill
        )
        
        # Task executor - applies composed skill
        self.executor = nn.Sequential(
            nn.Linear(128 + 20, 256),  # composed skill + object
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 20)  # result
        )
    
    def forward(self, skill1, skill2, obj):
        # Encode primitives
        enc1 = self.skill_encoder(skill1)
        enc2 = self.skill_encoder(skill2)
        
        # Compose
        composed = self.composer(torch.cat([enc1, enc2], dim=1))
        
        # Execute on object
        result = self.executor(torch.cat([composed, obj], dim=1))
        return result

def create_composition_task(batch_size=128):
    """
    Compositional tasks:
    - "move" + "forward" + object → moved_forward(object)
    - "rotate" + "left" + object → rotated_left(object)
    - "scale" + "up" + object → scaled_up(object)
    """
    skill1_list = []
    skill2_list = []
    objects = []
    results = []
    
    for _ in range(batch_size):
        obj = np.random.randn(20)
        
        # Choose composition type
        comp_type = np.random.choice(['move_forward', 'move_back', 'rotate_left', 
                                     'rotate_right', 'scale_up', 'scale_down'])
        
        if comp_type == 'move_forward':
            skill1 = np.zeros(20); skill1[0] = 1  # move
            skill2 = np.zeros(20); skill2[1] = 1  # forward
            result = obj + np.array([1.0] + [0]*19)  # shift first dim
            
        elif comp_type == 'move_back':
            skill1 = np.zeros(20); skill1[0] = 1  # move
            skill2 = np.zeros(20); skill2[2] = 1  # back
            result = obj - np.array([1.0] + [0]*19)
            
        elif comp_type == 'rotate_left':
            skill1 = np.zeros(20); skill1[3] = 1  # rotate
            skill2 = np.zeros(20); skill2[4] = 1  # left
            result = np.roll(obj, 1)  # rotate left
            
        elif comp_type == 'rotate_right':
            skill1 = np.zeros(20); skill1[3] = 1  # rotate
            skill2 = np.zeros(20); skill2[5] = 1  # right
            result = np.roll(obj, -1)  # rotate right
            
        elif comp_type == 'scale_up':
            skill1 = np.zeros(20); skill1[6] = 1  # scale
            skill2 = np.zeros(20); skill2[7] = 1  # up
            result = obj * 1.5
            
        else:  # scale_down
            skill1 = np.zeros(20); skill1[6] = 1  # scale
            skill2 = np.zeros(20); skill2[8] = 1  # down
            result = obj * 0.5
        
        skill1_list.append(skill1)
        skill2_list.append(skill2)
        objects.append(obj)
        results.append(result)
    
    return (torch.FloatTensor(np.array(skill1_list)).to(device),
            torch.FloatTensor(np.array(skill2_list)).to(device),
            torch.FloatTensor(np.array(objects)).to(device),
            torch.FloatTensor(np.array(results)).to(device))

print("="*70)
print("COMPOSITIONAL GENERALIZATION")
print("="*70)

model = CompositionNet().to(device)
opt = torch.optim.Adam(model.parameters(), lr=0.001)

print("\nTraining (500 epochs)...")

for epoch in range(500):
    skill1, skill2, obj, target = create_composition_task(batch_size=128)
    
    pred = model(skill1, skill2, obj)
    loss = F.mse_loss(pred, target)
    
    opt.zero_grad()
    loss.backward()
    opt.step()
    
    if epoch % 100 == 0:
        print(f"  Epoch {epoch}: Loss={loss.item():.4f}")

print("\n✅ Training complete")

# Test on novel compositions
print("\n" + "="*70)
print("TESTING NOVEL COMPOSITIONS")
print("="*70)

test_accs = []
for _ in range(20):
    skill1, skill2, obj, target = create_composition_task(batch_size=200)
    
    with torch.no_grad():
        pred = model(skill1, skill2, obj)
        
        # Check if prediction is close to target
        error = torch.abs(pred - target).mean(dim=1)
        acc = (error < 0.3).float().mean().item()  # Within 0.3 tolerance
        test_accs.append(acc)

avg_acc = np.mean(test_accs)
print(f"Composition Accuracy: {avg_acc*100:.1f}%")

if avg_acc >= 0.95:
    print("🎉 EXCEPTIONAL!")
elif avg_acc >= 0.90:
    print("✅ EXCELLENT!")
else:
    print("✅ Strong!")

torch.save(model.state_dict(), 'composition.pth')
print("💾 Saved!")
