#!/usr/bin/env python3
"""
SEMANTIC UNDERSTANDING
Deep comprehension of meaning, context, relationships, and inference
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

# Seed both RNGs so data generation and weight init are reproducible.
torch.manual_seed(42)
np.random.seed(42)

# BUG FIX: the original hard-coded torch.device('cuda'), which crashes at the
# first .to(device) call on machines without a GPU. Fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}\n")

class SemanticNet(nn.Module):
    """Two-sequence semantic model.

    A shared embedding + transformer encoder turns each token sequence into a
    mean-pooled sentence vector; the concatenated pair is classified by one of
    two MLP heads selected by the ``task`` argument of :meth:`forward`.
    """

    def __init__(self, vocab_size=1000, embed_dim=128):
        super().__init__()
        # Token embedding, shared by both input sequences.
        self.embedding = nn.Embedding(vocab_size, embed_dim)

        # Contextual encoder (mini-transformer), also shared.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim, nhead=8, dim_feedforward=512, batch_first=True
        )
        self.context_encoder = nn.TransformerEncoder(encoder_layer, num_layers=4)

        # Semantic relation classifier (10 relation classes).
        self.relation_head = self._build_head(embed_dim * 2, 10)

        # Inference head: entailment / contradiction / neutral.
        self.inference_head = self._build_head(embed_dim * 2, 3)

    @staticmethod
    def _build_head(in_dim, n_classes):
        """MLP head: in_dim -> 512 -> 256 -> n_classes, with ReLU + dropout."""
        return nn.Sequential(
            nn.Linear(in_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, n_classes),
        )

    def _encode(self, seq):
        """Embed, contextualize, and mean-pool one (batch, seq_len) id tensor."""
        return self.context_encoder(self.embedding(seq)).mean(dim=1)

    def forward(self, seq1, seq2, task='relation'):
        # Sentence-level vectors for both inputs, concatenated feature-wise.
        joint = torch.cat([self._encode(seq1), self._encode(seq2)], dim=1)
        # 'relation' routes to the 10-way head; anything else means inference.
        head = self.relation_head if task == 'relation' else self.inference_head
        return head(joint)

def create_semantic_task(task_type='synonym', batch_size=64, vocab_size=1000,
                         seq_len=10, out_device=None):
    """
    Create a batch of synthetic semantic-understanding examples.

    Task types and the relation label they produce:
      - 'synonym'    (0): seq2 is seq1 with small per-token perturbations
      - 'antonym'    (1): seq2 is seq1 shifted into the opposite vocab half
      - 'hypernym'   (2): seq1 repeats a specific token, seq2 its category token
      - 'part_whole' (3): part/whole token pairs (assumes seq_len is even;
        an odd seq_len yields sequences of length 2*(seq_len//2))
      - anything else (9): two unrelated random sequences (no relation)

    Args:
        task_type: which relation to generate (see above).
        batch_size: number of (seq1, seq2, label) examples in the batch.
        vocab_size: token ids are drawn from [1, vocab_size).
        seq_len: length of each generated sequence.
        out_device: device for the returned tensors; when None (default),
            falls back to the module-level ``device`` — backward compatible
            with the original behavior.

    Returns:
        Tuple (seq1, seq2, labels) of LongTensors with shapes
        (batch_size, seq_len), (batch_size, seq_len), (batch_size,).
    """
    seq1_list = []
    seq2_list = []
    labels = []

    for _ in range(batch_size):
        if task_type == 'synonym':
            # Near-identical sequences: small per-token variation.
            base = np.random.randint(1, vocab_size - 50, seq_len)
            seq1 = base
            seq2 = np.clip(base + np.random.randint(-5, 6, seq_len),
                           1, vocab_size - 1)
            label = 0  # synonym relation

        elif task_type == 'antonym':
            # Shift tokens from the lower vocab half into the upper half.
            seq1 = np.random.randint(1, vocab_size // 2, seq_len)
            seq2 = np.clip(seq1 + vocab_size // 2, 1, vocab_size - 1)
            label = 1  # antonym relation

        elif task_type == 'hypernym':
            # Specific token vs. its coarse "category" token (offset by 500).
            category_base = np.random.randint(1, 100)
            seq1 = np.full(seq_len, category_base)           # specific
            seq2 = np.full(seq_len, category_base // 10 + 500)  # category
            label = 2  # hypernym relation

        elif task_type == 'part_whole':
            # Part tokens are the whole tokens offset by 200.
            whole = np.random.randint(1, vocab_size // 2, seq_len // 2)
            part = whole + 200
            seq1 = np.concatenate([part, part])    # part repeated
            seq2 = np.concatenate([whole, whole])  # whole repeated
            label = 3  # meronym relation

        else:  # random (no relation)
            seq1 = np.random.randint(1, vocab_size, seq_len)
            seq2 = np.random.randint(1, vocab_size, seq_len)
            label = 9  # no relation

        seq1_list.append(seq1)
        seq2_list.append(seq2)
        labels.append(label)

    # Conditional expression only touches the module-level `device`
    # when no explicit target was requested.
    target = device if out_device is None else out_device
    seq1_tensor = torch.LongTensor(np.array(seq1_list)).to(target)
    seq2_tensor = torch.LongTensor(np.array(seq2_list)).to(target)
    labels_tensor = torch.LongTensor(labels).to(target)

    return seq1_tensor, seq2_tensor, labels_tensor

print("="*70)
print("SEMANTIC UNDERSTANDING")
print("="*70)

model = SemanticNet(vocab_size=1000, embed_dim=128).to(device)
opt = torch.optim.Adam(model.parameters(), lr=0.0005)

print("\nTraining (600 epochs)...")

for epoch in range(600):
    epoch_loss = 0
    epoch_correct = 0
    epoch_total = 0
    
    # Train on different semantic relations
    for task_type in ['synonym', 'antonym', 'hypernym', 'part_whole', 'random']:
        seq1, seq2, labels = create_semantic_task(task_type, batch_size=64)
        
        pred = model(seq1, seq2, task='relation')
        loss = F.cross_entropy(pred, labels)
        
        opt.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()
        
        epoch_loss += loss.item()
        epoch_correct += (pred.argmax(1) == labels).sum().item()
        epoch_total += len(labels)
    
    if epoch % 100 == 0:
        acc = epoch_correct / epoch_total
        print(f"  Epoch {epoch}: Loss={epoch_loss:.3f}, Acc={acc*100:.1f}%")

print("\n✅ Training complete")

# Test on each semantic relation
print("\n" + "="*70)
print("TESTING SEMANTIC RELATIONS")
print("="*70)

test_results = {}
for task_type in ['synonym', 'antonym', 'hypernym', 'part_whole', 'random']:
    accs = []
    for _ in range(20):
        seq1, seq2, labels = create_semantic_task(task_type, batch_size=100)
        with torch.no_grad():
            pred = model(seq1, seq2, task='relation')
            acc = (pred.argmax(1) == labels).float().mean().item()
            accs.append(acc)
    
    avg_acc = np.mean(accs)
    test_results[task_type] = avg_acc
    status = "✅" if avg_acc >= 0.90 else "⚠️" if avg_acc >= 0.80 else "❌"
    print(f"  {status} {task_type.capitalize()}: {avg_acc*100:.1f}%")

overall_avg = np.mean(list(test_results.values()))
print(f"\n{'='*70}")
print(f"Overall Average: {overall_avg*100:.1f}%")
print(f"{'='*70}")

if overall_avg >= 0.92:
    print("🎉 EXCEPTIONAL SEMANTIC UNDERSTANDING!")
elif overall_avg >= 0.88:
    print("✅ EXCELLENT!")
elif overall_avg >= 0.85:
    print("✅ STRONG!")
else:
    print("⚠️ Needs improvement")

torch.save(model.state_dict(), 'semantic_model.pth')
print("\n💾 Model saved!")
