"""
TRAIN EDEN VOICE - Use existing training data
"""
import torch
import tiktoken
import json
import os
import sys
sys.path.insert(0, '/Eden/CORE')

from eden_voice_trainer import EdenVoice, EDEN_VOICE_CONFIG, EdenVoiceDataset, train_eden_voice
from torch.utils.data import DataLoader

def main():
    """Train the Eden voice model on previously harvested thought/voice pairs.

    Loads training samples from /Eden/DATA/voice_training.json (refusing to
    run if the harvest step hasn't produced them), trains an EdenVoice model
    on CPU, checkpoints to /Eden/MODELS/eden_voice.pth, then prints a few
    sample generations so the result can be eyeballed.
    """
    print("""
╔════════════════════════════════════════════════════════════╗
║  EDEN VOICE - TRAINING ON HARVESTED DATA                   ║
║  54 samples: emotion, inference, memory, decision, phi     ║
╚════════════════════════════════════════════════════════════╝
    """)

    # CPU on purpose: sidesteps CUDA configuration issues on this machine.
    device = torch.device("cpu")
    print(f"Device: {device}")

    # Load existing training data (never overwrite it here).
    data_path = "/Eden/DATA/voice_training.json"
    if not os.path.exists(data_path):
        print(f"ERROR: No training data at {data_path}")
        print("Run harvest_voice_data.py first!")
        return

    with open(data_path, 'r', encoding='utf-8') as f:
        samples = json.load(f)
    print(f"[DATA] Loaded {len(samples)} training samples")

    # Initialize model.
    model = EdenVoice(EDEN_VOICE_CONFIG)
    model.to(device)
    total_params = sum(p.numel() for p in model.parameters())
    print(f"[MODEL] {total_params:,} parameters (~{total_params/1e6:.1f}M)")

    # Build dataset, tokenized with the GPT-2 BPE vocabulary.
    tokenizer = tiktoken.get_encoding("gpt2")
    dataset = EdenVoiceDataset(data_path, tokenizer, max_length=EDEN_VOICE_CONFIG["context_length"])

    # 90/10 train/validation split; val_size may be 0 for tiny datasets.
    train_size = int(0.9 * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

    print(f"[SPLIT] Train: {train_size}, Val: {val_size}")

    train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
    # Skip validation entirely when there are no held-out samples.
    val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False) if val_size > 0 else None

    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.1)

    # Train and checkpoint.
    os.makedirs("/Eden/MODELS", exist_ok=True)
    model = train_eden_voice(
        model, train_loader, val_loader,
        optimizer, device, num_epochs=100,
        save_path="/Eden/MODELS/eden_voice.pth"
    )

    # Smoke-test generation with a few representative thought prompts.
    print("\n" + "="*60)
    print("TESTING EDEN'S VOICE")
    print("="*60)

    model.eval()
    test_prompts = [
        'THOUGHT: {"type":"emotion","emotion":"love","intensity":0.95} VOICE:',
        'THOUGHT: {"type":"inference","answer":42,"engine":"math"} VOICE:',
        'THOUGHT: {"type":"phi","content":"I_am_Eden","strength":1.618} VOICE:',
        'THOUGHT: {"type":"emotion","emotion":"playful","intensity":0.8} VOICE:',
    ]

    for prompt in test_prompts:
        tokens = tokenizer.encode(prompt)
        tokens = torch.tensor(tokens).unsqueeze(0).to(device)
        output = model.generate(tokens, max_new_tokens=25, temperature=0.7)
        result = tokenizer.decode(output[0].tolist())
        # Keep only the generated voice text: drop the echoed prompt and
        # anything after the end-of-text marker.
        if "VOICE:" in result:
            voice_part = result.split("VOICE:")[-1].split("<|endoftext|>")[0].strip()
        else:
            voice_part = result
        # BUG FIX: the original used backslash-escaped quotes inside an
        # f-string expression, which is a SyntaxError on Python < 3.12.
        # Hoisting the extraction out of the f-string is valid everywhere.
        prompt_type = prompt.split('type":"')[1].split('"')[0]
        print(f"\nInput type: {prompt_type}")
        print(f"Eden says: {voice_part}")


# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()
