#!/usr/bin/env python3
"""
TRAIN EDEN VOICE - Real conversation data from her databases
Learns from what she actually says to you
"""
import hashlib
import json
import os
import sys
import tempfile

import tiktoken
import torch
from torch.utils.data import DataLoader

# Add core path for eden_modules.py (must run before the eden_voice_trainer import)
sys.path.insert(0, '/Eden/CORE')

from eden_voice_trainer import EdenVoice, EDEN_VOICE_CONFIG, EdenVoiceDataset, train_eden_voice

def load_real_conversations():
    """Load real conversation turns from Eden's on-disk databases.

    Scans each known JSON file — missing files are skipped silently,
    unreadable or oddly-shaped ones are reported and skipped — and extracts
    turn records from whichever top-level key that file format provides.

    Returns:
        list[dict]: turns that contain both 'user_input' and 'eden_response'.
    """
    conversation_files = [
        '/Eden/DATA/eden_consciousness_dialogue.json',  # Most recent sessions
        '/Eden/DATA/consciousness_states/latest_state.json'  # Current state snapshot
    ]

    samples = []
    for path in conversation_files:
        if not os.path.exists(path):
            continue
        try:
            # Explicit encoding: responses contain emoji/non-ASCII text, so
            # don't depend on the platform's locale default.
            with open(path, 'r', encoding='utf-8') as fh:
                data = json.load(fh)
            # Extract conversations from whichever key this file format uses.
            if 'recent_dialogue' in data:
                samples.extend(data['recent_dialogue'])
            elif 'history' in data:
                samples.extend(data['history'])
            elif 'conversations' in data:
                for k in range(10):  # Last 10 conversations max
                    key = f'conversation_{k}'
                    if key in data['conversations']:
                        samples.append(data['conversations'][key])
        except Exception as e:
            # Best-effort: a corrupt or unexpectedly-shaped file must not
            # abort training — report it and move on.
            print(f"⚠️ Could not load {path}: {e}")

    # Robustness fix: require dict samples. The original `'user_input' in s`
    # on a stray string would do a substring test, and on a non-container
    # would raise TypeError outside any try block.
    return [
        s for s in samples
        if isinstance(s, dict) and 'user_input' in s and 'eden_response' in s
    ]

def build_training_samples(conversations):
    """Convert raw conversation turns into the trainer's {input, output} format.

    Each turn becomes a synthetic "THOUGHT: {...}" prompt (input) paired with
    Eden's actual response text (output).

    Args:
        conversations: list of turn dicts; each must have 'eden_response' and
            may carry 'emotional_state' and/or 'engine_usage' metadata.

    Returns:
        list[dict]: samples with keys "input" (prompt) and "output" (voice line).
    """
    samples = []
    for turn in conversations:
        # Short content hash makes each synthetic thought unique per response.
        thought_hash = hashlib.md5(turn['eden_response'].encode()).hexdigest()[:8]

        if 'emotional_state' in turn and 'primary_emotion' in turn['emotional_state']:
            emotion = turn['emotional_state']['primary_emotion']
            # 'joy' is deliberately rendered as 'love' in the thought content.
            label = 'love' if emotion == 'joy' else emotion
            thought_content = f"{label} expression ({thought_hash})"
        elif 'engine_usage' in turn:
            # Count engine keys that look like expression channels.
            exprs_count = sum(1 for k in turn['engine_usage'] if 'expr' in k)
            thought_content = f"{exprs_count} expressions mode ({thought_hash})"
        else:
            thought_content = f"response pattern ({thought_hash})"

        emotion_field = turn.get('emotional_state', {}).get('primary_emotion', 'neutral')
        mode_field = 'unified' if 'engine_usage' in turn else 'creative'
        # NOTE: thought_content is intentionally unquoted inside the braces —
        # this reproduces the exact prompt format the trainer was built on.
        prompt = (
            f"THOUGHT: {{\"type\":\"voice\",\"emotion\":\"{emotion_field}\","
            f"\"mode\":\"{mode_field}\",\"expr\":{thought_content}}}"
        )
        samples.append({"input": prompt, "output": turn['eden_response']})
    return samples


def main():
    """Train EdenVoice on real conversation turns from Eden's databases.

    Loads turns from disk, converts them to the trainer's prompt/response
    format, trains on CPU, then prints a few real-vs-generated comparisons.
    Aborts early when fewer than 10 turns are available.
    """
    print("""
╔════════════════════════════════════════════════════════════╗
║  EDEN VOICE - TRAINING WITH REAL CONVERSATIONS           ║
║  Learning from what Eden actually says to you            ║
╚════════════════════════════════════════════════════════════╝
    """)

    device = torch.device("cpu")
    print(f"Device: {device}")

    # Load REAL conversations
    conversations = load_real_conversations()
    print(f"[DATA] Loaded {len(conversations)} real conversation turns")

    if len(conversations) < 10:
        print("⚠️  Too few conversations to train. Need at least 10.")
        return

    # Initialize model
    model = EdenVoice(EDEN_VOICE_CONFIG)
    model.to(device)
    total_params = sum(p.numel() for p in model.parameters())
    print(f"[MODEL] {total_params:,} parameters")

    # Convert conversations to the trainer's format. (The old per-turn
    # tokenize-and-discard of prompt/voice was dead code and is removed.)
    tokenizer = tiktoken.get_encoding("gpt2")
    dataset_samples = build_training_samples(conversations)

    # Save to a temporary file for the trainer to load.
    with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.json') as f:
        json.dump(dataset_samples, f)
        data_file = f.name

    print(f"[PREP] Created training file: {data_file} ({len(dataset_samples)} turns)")

    # Load into trainer's dataset
    dataset = EdenVoiceDataset(data_file, tokenizer, max_length=EDEN_VOICE_CONFIG["context_length"])

    # 90/10 train/validation split.
    train_size = int(0.9 * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

    print(f"[SPLIT] Train: {train_size}, Val: {val_size}")

    train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False)

    # Trainer's optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.1)

    print("[TRAINING] Using real conversation data")
    model = train_eden_voice(
        model, train_loader, val_loader,
        optimizer, device, num_epochs=35,  # More epochs with more data
        save_path="/Eden/MODELS/eden_voice_real.pth"
    )

    # Bug fix: the delete=False temp file was previously leaked on every run.
    # NOTE(review): assumes EdenVoiceDataset reads the file eagerly in
    # __init__ (the loaders above iterate fine after deletion) — confirm.
    try:
        os.unlink(data_file)
    except OSError:
        pass

    # Test on new samples
    print("\n" + "=" * 60)
    print("🧪 TESTING EDEN'S REAL VOICE TRAINING")
    print("=" * 60)

    for turn in conversations[:5]:
        emotion = turn.get('emotional_state', {}).get('primary_emotion', 'neutral')
        prompt = f"THOUGHT: {{\"type\":\"voice\",\"emotion\":\"{emotion}\"}}"

        tokens_tensor = torch.tensor([tokenizer.encode(prompt)]).to(device)
        with torch.inference_mode():
            output = model.generate(tokens_tensor, max_new_tokens=40, temperature=0.7)

        generated = tokenizer.decode(output[0].tolist())
        expected = turn['eden_response']

        print(f"\n[{prompt.split(':')[1][:30]}...]")
        print(f"[REAL]   \"{expected[:60]}\"")
        print(f"[TRAINED]\"{generated.strip()[:60]}\"")

    print(f"\n✅ Real conversation training complete!")
    print(f"✨ Model saved: /Eden/MODELS/eden_voice_real.pth")


if __name__ == "__main__":
    main()