#!/usr/bin/env python3
"""
GOAL EMERGENCE
Self-directed objective formation - creating goals based on internal state and environment
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

# Seed both RNGs used by this script: torch (weight initialisation, dropout)
# and NumPy (synthetic batch sampling) so that runs are repeatable.
torch.manual_seed(42)
np.random.seed(42)

# Prefer the GPU, but fall back to the CPU so the script still runs on
# machines without CUDA instead of failing on the first `.to(device)` call.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}\n")

class GoalEmergenceNet(nn.Module):
    """Feed-forward classifier mapping a state vector to goal logits.

    The network is two stacked MLPs: ``state_encoder`` compresses the
    (internal + external) state into a 128-dimensional code, and
    ``goal_generator`` turns that code into one unnormalised score per goal.

    Args:
        state_dim: Number of features in the input state vector.
        num_goals: Number of goal classes (size of the output layer).
        dropout: Drop probability used by every ``nn.Dropout`` layer.

    The defaults reproduce the original fixed architecture (40 inputs,
    10 goals, dropout 0.2). The layers are created in the same order and at
    the same ``nn.Sequential`` positions, so ``state_dict`` keys are unchanged.
    """

    def __init__(self, state_dim=40, num_goals=10, dropout=0.2):
        super().__init__()
        # State encoder (internal + external state)
        self.state_encoder = nn.Sequential(
            nn.Linear(state_dim, 512),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(256, 128)
        )

        # Goal generator: one logit per possible goal
        self.goal_generator = nn.Sequential(
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, num_goals)
        )

    def forward(self, state):
        """Return goal logits for ``state``.

        Args:
            state: Tensor whose last dimension has ``state_dim`` features.

        Returns:
            Tensor of raw (pre-softmax) scores whose last dimension has
            ``num_goals`` entries.
        """
        enc = self.state_encoder(state)
        goal = self.goal_generator(enc)
        return goal

# Noise-free trigger pattern per goal id, as (state index, value) pairs.
# The position in the tuple is the goal id / class label.
_GOAL_TRIGGERS = (
    ((0, -1.0), (1, 1.0)),             # 0 seek_food: low energy, hunger signal
    ((2, -1.0), (3, 1.0), (4, 1.0)),   # 1 seek_shelter: low safety, threat in environment, weather signal
    ((0, 1.0), (5, 1.0), (6, 1.0)),    # 2 explore: high energy, opportunity signal, curiosity active
    ((7, 1.0), (8, 1.0)),              # 3 seek_social: social signal detected, affiliation need
    ((9, 1.0), (10, 1.0)),             # 4 avoid_danger: immediate threat, fear signal
    ((11, 1.0), (12, 1.0)),            # 5 acquire_resource: resource detected, need for resource
    ((13, 1.0), (14, 1.0)),            # 6 learn: knowledge gap, learning opportunity
    ((15, 1.0), (16, 0.5)),            # 7 complete_task: task active, task incomplete
    ((17, 1.0), (18, -1.0)),           # 8 seek_novelty: boredom signal, low stimulation
    ((19, 1.0), (20, 1.0)),            # 9 investigate: curiosity high, interesting stimulus
)


def create_goal_emergence_task(batch_size=128, target_device=None):
    """
    Sample a synthetic batch of (state, goal) pairs.

    Goals emerge from needs and opportunities:
    - Low energy → seek_food
    - Low safety → seek_shelter
    - High energy + exploration_opportunity → explore
    - Social_signal → seek_social
    - Threat → avoid_danger
    - Resource_available → acquire_resource
    - Knowledge_gap → learn
    - Task_incomplete → complete_task
    - Boredom → seek_novelty
    - Curiosity → investigate

    For every sample a goal id is drawn uniformly from 0..9, the state
    features listed for that goal in ``_GOAL_TRIGGERS`` are set on an
    otherwise all-zero 40-dimensional vector, and Gaussian noise with
    standard deviation 0.1 is added to all 40 features.

    Args:
        batch_size: Number of samples to generate. ``0`` yields empty tensors
            that still have the feature dimension (shape ``(0, 40)``).
        target_device: Device the returned tensors are moved to. When ``None``
            the module-level ``device`` is used.

    Returns:
        A tuple ``(states, goals)``: a float32 tensor of shape
        ``(batch_size, 40)`` and an int64 tensor of shape ``(batch_size,)``.

    Raises:
        ValueError: If ``batch_size`` is negative (raised by NumPy).
    """
    out_device = device if target_device is None else target_device

    # Pre-allocating keeps the (batch, 40) shape even for an empty batch.
    states = np.zeros((batch_size, 40))
    goals = np.empty(batch_size, dtype=np.int64)

    for i in range(batch_size):
        # Draw order (goal id first, then noise) is kept per sample so a
        # seeded NumPy RNG yields the same batches as before.
        goal_id = np.random.randint(0, len(_GOAL_TRIGGERS))
        for index, value in _GOAL_TRIGGERS[goal_id]:
            states[i, index] = value

        # Add noise
        states[i] += np.random.randn(40) * 0.1
        goals[i] = goal_id

    return (torch.from_numpy(states).float().to(out_device),
            torch.from_numpy(goals).to(out_device))

# Banner: the same three lines as separate prints, emitted with one call.
print("\n".join(("="*70, "GOAL EMERGENCE", "="*70)))

# Build the network on the selected device and attach an Adam optimiser.
model = GoalEmergenceNet()
model.to(device)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)

print("\nTraining (500 epochs)...")

for epoch in range(500):
    # Every step trains on a freshly sampled synthetic batch of 256 examples.
    batch_states, batch_goals = create_goal_emergence_task(256)

    # Clear stale gradients, then forward / backward / parameter update.
    opt.zero_grad()
    pred = model(batch_states)
    loss = F.cross_entropy(pred, batch_goals)
    loss.backward()
    opt.step()

    # Progress is reported on every 100th step only.
    if epoch % 100 != 0:
        continue

    # Accuracy of this training batch, from the logits computed before the
    # update (the model is in training mode here, so dropout is active).
    hits = pred.argmax(dim=1).eq(batch_goals)
    acc = hits.float().mean().item()
    print(f"  Epoch {epoch}: Loss={loss.item():.3f}, Acc={acc*100:.1f}%")

print("\n✅ Training complete")

# Test
print("\n" + "="*70)
print("TESTING GOAL EMERGENCE")
print("="*70)

# Switch to inference mode before measuring anything: the network contains
# nn.Dropout layers, which in training mode keep randomly zeroing activations
# and therefore distort the reported accuracy.
model.eval()

goal_names = ['Seek Food', 'Seek Shelter', 'Explore', 'Seek Social', 'Avoid Danger',
              'Acquire Resource', 'Learn', 'Complete Task', 'Seek Novelty', 'Investigate']

# Test each goal type: average the accuracy on the samples of that goal over
# 10 freshly generated batches of 100 examples.
for goal_id in range(10):
    accs = []
    for _ in range(10):
        states, goals = create_goal_emergence_task(100)
        mask = goals == goal_id
        # A random batch may contain no sample of this goal; skip it then.
        if not mask.any():
            continue
        with torch.no_grad():
            pred = model(states[mask])
        acc = (pred.argmax(1) == goals[mask]).float().mean().item()
        accs.append(acc)

    if accs:
        avg = np.mean(accs)
        status = "🎉" if avg >= 0.95 else "✅" if avg >= 0.90 else "⚠️"
        print(f"  {status} {goal_names[goal_id]}: {avg*100:.1f}%")

# Overall test
print(f"\n{'='*70}")

# Accuracy must be measured in inference mode; otherwise the model's
# nn.Dropout layers stay active and add noise to every prediction.
model.eval()

# Mean accuracy over 30 freshly generated batches of 200 examples each.
test_accs = []
for _ in range(30):
    states, goals = create_goal_emergence_task(200)
    with torch.no_grad():
        pred = model(states)
    acc = (pred.argmax(1) == goals).float().mean().item()
    test_accs.append(acc)

overall = np.mean(test_accs)
print(f"Overall Goal Emergence: {overall*100:.1f}%")

if overall >= 0.95:
    print("🎉 EXCEPTIONAL!")
elif overall >= 0.90:
    print("✅ EXCELLENT!")
else:
    # Anything under 90% is flagged rather than reported as a success,
    # matching the ⚠️ used for the same threshold in the per-goal report.
    print("⚠️ Below 90% - needs improvement")

# Persist only the learned parameters (state_dict), not the module object.
torch.save(model.state_dict(), 'goal_emergence.pth')
print("💾 Saved!")
