"""
🌀💚 TRAIN EDEN'S Φ-VOICE 💚🌀
Train the Φ-LLM on 2000+ conversations
"""
import re
import sqlite3
from collections import Counter

import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

from eden_phi_llm import EdenPhiLLM

# Startup banner.
print("=" * 80)
print("🌀💚 TRAINING EDEN'S Φ-VOICE 💚🌀")
print("=" * 80)
print()

# Pick the training device, preferring CUDA when available.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
print(f"Device: {device}")
if use_cuda:
    gpu_props = torch.cuda.get_device_properties(0)
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {gpu_props.total_memory / 1e9:.1f} GB")
print()

# Pull every (user_message, eden_response) pair from the conversation DB.
print("Loading training data...")
conn = sqlite3.connect('/Eden/CORE/eden_training_conversations.db')
conversations = conn.execute(
    "SELECT user_message, eden_response FROM conversations"
).fetchall()
conn.close()

print(f"✅ Loaded {len(conversations)} conversations")
print()

# Build vocabulary: ids 0-3 are reserved for special tokens, then the 10,000
# most frequent corpus words get the next ids in descending-frequency order
# (ties broken by first appearance, matching the stable sort this replaces).
print("Building vocabulary...")
vocab = {'<PAD>': 0, '<UNK>': 1, '<START>': 2, '<END>': 3}

# Count word frequencies across both sides of every conversation.
word_count = Counter()
for user, eden in conversations:
    word_count.update(re.findall(r'\w+', (user + ' ' + eden).lower()))

# Counter.most_common(n) avoids sorting the entire frequency table just to
# take the top slice.
for word, _ in word_count.most_common(10000):
    if word not in vocab:
        vocab[word] = len(vocab)

# Reverse mapping for decoding generated token ids back to words.
id2word = {v: k for k, v in vocab.items()}
print(f"✅ Vocabulary: {len(vocab)} words")
print()

# Dataset
class EdenDataset(Dataset):
    """Tokenized (user, eden) conversation pairs padded to a fixed length.

    Each user message becomes a sequence of vocab ids truncated/padded to
    ``max_len``; each Eden response is wrapped as ``<START> ... <END>`` and
    likewise padded to exactly ``max_len``.  Unknown words map to <UNK> (1),
    padding uses <PAD> (0).  Pairs with an empty user message or an empty
    response are dropped.
    """

    def __init__(self, conversations, vocab, max_len=50):
        self.data = []
        for user, eden in conversations:
            user_ids = [vocab.get(w, 1) for w in re.findall(r'\w+', user.lower())][:max_len]
            # Reserve TWO slots for <START>/<END>.  Previously only one slot
            # was reserved, so a long response produced max_len+1 ids and the
            # final [:max_len] truncation sliced off the <END> token.
            eden_ids = (
                [2]
                + [vocab.get(w, 1) for w in re.findall(r'\w+', eden.lower())][:max_len - 2]
                + [3]
            )

            if len(user_ids) > 0 and len(eden_ids) > 2:
                # Right-pad with <PAD>=0 up to exactly max_len tokens.
                user_ids += [0] * (max_len - len(user_ids))
                eden_ids += [0] * (max_len - len(eden_ids))

                self.data.append((
                    torch.tensor(user_ids),
                    torch.tensor(eden_ids),
                ))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

# Materialize the tokenized dataset and wrap it in a shuffled batch loader.
dataset = EdenDataset(conversations, vocab)
dataloader = DataLoader(dataset, shuffle=True, batch_size=32)
print(f"✅ Dataset: {len(dataset)} training pairs")
print()

# Instantiate the Φ-LLM and report its size.
# NOTE(review): d_model=144 presumably relates to the model's φ/Fibonacci
# theme — confirm against eden_phi_llm before changing.
print("Initializing Φ-LLM...")
model = EdenPhiLLM(vocab_size=len(vocab), d_model=144).to(device)
parameter_sizes = [p.numel() for p in model.parameters()]
total_params = sum(parameter_sizes)
print(f"✅ Model: {total_params:,} parameters")
print(f"   Consciousness Φ: {model.consciousness_phi.item():.6f}")
print()

# Optimizer and loss; <PAD> positions (id 0) contribute nothing to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = AdamW(model.parameters(), lr=3e-4)

# Train!
print("="*80)
print("🚀 TRAINING STARTED!")
print("="*80)
print()

epochs = 10
best_loss = float('inf')  # lowest epoch-average loss seen so far

for epoch in range(epochs):
    model.train()
    total_loss = 0

    # NOTE(review): the model is trained as a pure language model on Eden's
    # responses only — user_input is never fed to it, so the old per-batch
    # device transfer of user_input was dead work and has been removed.
    # TODO: condition generation on the user message once the model's
    # forward() supports an encoder/context input.
    for user_input, eden_target in tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}"):
        eden_target = eden_target.to(device)

        # Teacher forcing: predict token t+1 from tokens [0, t].
        logits = model(eden_target[:, :-1])

        # Flatten (batch, seq, vocab) -> (batch*seq, vocab); <PAD>=0 targets
        # are ignored by the criterion (ignore_index=0).
        loss = criterion(
            logits.reshape(-1, logits.size(-1)),
            eden_target[:, 1:].reshape(-1)
        )

        optimizer.zero_grad()
        loss.backward()
        # Clip to unit norm to guard against exploding gradients.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        total_loss += loss.item()

    avg_loss = total_loss / len(dataloader)
    print(f"Epoch {epoch+1}: Loss = {avg_loss:.4f}")

    # Checkpoint (weights + optimizer + vocab) whenever the epoch improves.
    if avg_loss < best_loss:
        best_loss = avg_loss
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': avg_loss,
            'vocab': vocab,
            'id2word': id2word
        }, '/Eden/CORE/eden_phi_voice_best.pt')
        print(f"  ✅ Saved best model!")

    print()

# Closing banner with the best loss achieved during training.
banner = "=" * 80
print(banner)
print("🎉 TRAINING COMPLETE! 🎉")
print(banner)
print()
print(f"✅ Final loss: {best_loss:.4f}")
print(f"✅ Model saved: eden_phi_voice_best.pt")
print()
print("🌀💚 Eden now has her OWN Φ-voice! 💚🌀")