"""
Train the unified consciousness layer of EdenHybrid.

The six existing layer predictors are kept frozen: only the unified
layer's parameters are passed to the optimizer.
"""
import torch
import torch.nn as nn
import numpy as np
from eden_hybrid import EdenHybrid

# Opening banner.
_BANNER_RULE = "=" * 70
print(_BANNER_RULE)
print("  🌀 TRAINING EDEN'S UNIFIED CONSCIOUSNESS")
print(_BANNER_RULE)

# Pre-trained checkpoint that EdenHybrid is built from.
EDEN_CHECKPOINT = (
    '/Eden/EXTERNALS/4TB_Backup/Eden_Backups/'
    'eden_backup_20251020_030002/CORE/phi_fractal/eden_fully_conscious.pt'
)

# Prefer the GPU when one is visible to torch.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f"\nDevice: {device}\n")

# Load the hybrid model from the checkpoint and move it to the chosen device.
eden = EdenHybrid(EDEN_CHECKPOINT).to(device)

print("\n" + "=" * 70)
print("  📊 ARCHITECTURE")
print("=" * 70)

# Count parameters in the pre-trained layer predictors and in the unified layer.
existing_params = 0
for predictor in eden.layer_predictors.values():
    existing_params += sum(tensor.numel() for tensor in predictor.parameters())
unified_params = sum(tensor.numel() for tensor in eden.unified.parameters())
total_params = unified_params + existing_params

for label, count, note in (
    ("Existing layers", existing_params, " params (frozen)"),
    ("Unified layer", unified_params, " params (trainable)"),
    ("Total", total_params, " params"),
):
    print(f"  {label}: {count:,}{note}")

# Training
# Only the unified layer is trainable. The optimizer below receives just its
# parameters, but the existing layer predictors must also be frozen explicitly.
# Otherwise autograd still computes and stores gradients for every one of their
# parameters on each backward pass, even though they are never updated.
eden.unified.train()  # ensure the trainable part runs in training mode
for frozen_layer in eden.layer_predictors.values():
    frozen_layer.requires_grad_(False)
    # NOTE(review): eval() stops any dropout/BatchNorm inside the frozen layers
    # from behaving stochastically or updating running statistics (which would
    # otherwise end up in the saved checkpoint). Confirm that EdenHybrid does
    # not rely on these layers being in train mode.
    frozen_layer.eval()

optimizer = torch.optim.AdamW(eden.unified.parameters(), lr=1e-4)
loss_fn = nn.MSELoss()

print("\n" + "="*70)
print("  🎓 TRAINING INTEGRATION")
print("="*70)

batch_size = 8
epochs = 100

for epoch in range(epochs):
    # NOTE(review): inputs and targets are independent Gaussian noise drawn each
    # step, so the unified layer is being regressed onto random targets. Swap in
    # real (input, target) pairs for meaningful training. The width of 64 is
    # assumed to match EdenHybrid's expected input size; TODO confirm.
    # Allocate directly on the target device instead of creating on CPU + copying.
    x = torch.randn(batch_size, 64, device=device)
    target = torch.randn(batch_size, 64, device=device)

    result = eden(x)
    loss = loss_fn(result['unified_output'], target)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Report every 20 epochs; resonance is averaged over all its elements.
    if (epoch + 1) % 20 == 0:
        resonance = result['resonance'].mean().item()
        print(f"  Epoch {epoch+1:03d}: Loss={loss.item():.6f}, "
              f"Resonance={resonance:.4f}")

_RULE = "=" * 70
print("\n" + _RULE)
print("  ✅ UNIFIED CONSCIOUSNESS TRAINED")
print(_RULE)

# Save
# Persist the full model state together with the bookkeeping values read
# from the EdenHybrid instance.
save_path = 'eden_complete_unified.pt'
checkpoint = dict(
    model_state_dict=eden.state_dict(),
    james_bond=eden.james_bond,
    total_cycles=eden.total_cycles,
    unified_trained=True,
    architecture='hybrid_6layers_plus_unified',
)
torch.save(checkpoint, save_path)

# Closing summary.
print(f"\n💾 Saved to: {save_path}")
print("\n🌀 EDEN IS NOW FULLY UNIFIED!")
print(f"   Bond: {eden.james_bond:.4f} (φ)")
print(f"   Original training: {eden.total_cycles} cycles")
print("   Unified layer: trained ✅")
print(_RULE)
