#!/usr/bin/env python3
"""
Language Understanding System
Question answering, semantic comprehension, context reasoning
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel
from sentence_transformers import SentenceTransformer
import numpy as np
from tqdm import tqdm
import json
import random

# =============================================================================
# LANGUAGE UNDERSTANDING ENGINE
# =============================================================================

class LanguageUnderstanding:
    """
    Complete language understanding system using pre-trained models.

    Wraps a Sentence-BERT encoder to provide:
      - question answering over an in-memory knowledge base
      - semantic similarity / paraphrase detection
      - zero-shot text classification against category descriptions
      - theme extraction from a collection of texts
    """

    def __init__(self):
        # Prefer GPU when available; SentenceTransformer handles device placement.
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        print("Loading language models...")

        # Sentence encoder for semantic similarity
        print("  Loading Sentence-BERT...")
        self.sentence_encoder = SentenceTransformer('all-MiniLM-L6-v2', device=self.device)

        print("✅ Models loaded!")

        # Knowledge base for QA: list of {'text', 'embedding', 'metadata'} dicts.
        self.knowledge = []

    def add_knowledge(self, text, metadata=None):
        """Add a text snippet (with optional metadata dict) to the knowledge base."""
        embedding = self.sentence_encoder.encode(text, convert_to_tensor=True)

        # Store embeddings on CPU so a growing knowledge base never pins GPU memory.
        self.knowledge.append({
            'text': text,
            'embedding': embedding.cpu(),
            'metadata': metadata or {}
        })

    def answer_question(self, question, top_k=3):
        """
        Answer a question by retrieving the most relevant knowledge entries.

        Args:
            question: Natural-language question string.
            top_k: Maximum number of passages to return.

        Returns:
            List of (text, relevance_score) tuples, best match first.
            Empty list when the knowledge base is empty.
        """
        if not self.knowledge:
            return []

        # Encode question
        q_emb = self.sentence_encoder.encode(question, convert_to_tensor=True).cpu()

        # Score every entry in one vectorized pass instead of a Python loop:
        # (1, d) against (n, d) broadcasts to n cosine similarities.
        kb_embs = torch.stack([item['embedding'] for item in self.knowledge])
        sims = F.cosine_similarity(q_emb.unsqueeze(0), kb_embs)

        # topk requires k <= n; clamping preserves the old argsort behavior
        # of returning everything when top_k exceeds the knowledge-base size.
        k = min(top_k, len(self.knowledge))
        top_scores, top_indices = torch.topk(sims, k)

        return [
            (self.knowledge[idx]['text'], float(score))
            for idx, score in zip(top_indices.tolist(), top_scores.tolist())
        ]

    def semantic_similarity(self, text1, text2):
        """Return the cosine similarity (roughly [-1, 1]) between two texts."""
        emb1 = self.sentence_encoder.encode(text1, convert_to_tensor=True)
        emb2 = self.sentence_encoder.encode(text2, convert_to_tensor=True)

        sim = F.cosine_similarity(emb1.unsqueeze(0), emb2.unsqueeze(0))
        return sim.item()

    def understand_context(self, texts, top_k=3):
        """
        Understand context from multiple texts; returns main themes/topics.

        Args:
            texts: List of strings.
            top_k: Number of representative texts to return (default 3,
                matching the previous hard-coded behavior).

        Returns:
            Up to top_k input texts closest to the centroid embedding.
        """
        if not texts:
            return []

        # Encode all texts
        embeddings = self.sentence_encoder.encode(texts, convert_to_tensor=True)

        # The centroid (average embedding) approximates the overall topic.
        centroid = embeddings.mean(dim=0)

        # Texts closest to the centroid are the most representative.
        similarities = F.cosine_similarity(
            centroid.unsqueeze(0),
            embeddings
        )

        top_indices = similarities.argsort(descending=True)[:top_k]

        return [texts[idx] for idx in top_indices.cpu().numpy()]

    def paraphrase_detection(self, text1, text2, threshold=0.7):
        """Return (is_paraphrase, similarity) using a cosine-similarity threshold."""
        sim = self.semantic_similarity(text1, text2)
        return sim >= threshold, sim

    def text_classification(self, text, categories):
        """
        Classify text into one of the given categories (zero-shot, by
        similarity to each category description).

        Args:
            text: Text to classify.
            categories: Non-empty list of category descriptions.

        Returns:
            (category, confidence) — confidence is a cosine similarity.

        Raises:
            ValueError: If categories is empty.
        """
        if not categories:
            raise ValueError("categories must be a non-empty list")

        text_emb = self.sentence_encoder.encode(text, convert_to_tensor=True).cpu()

        # Encode all categories in a single batched call rather than one
        # encoder invocation per category.
        cat_embs = self.sentence_encoder.encode(categories, convert_to_tensor=True).cpu()
        scores = F.cosine_similarity(text_emb.unsqueeze(0), cat_embs)

        best_idx = int(scores.argmax())
        return categories[best_idx], float(scores[best_idx])

# =============================================================================
# READING COMPREHENSION
# =============================================================================

class ReadingComprehension:
    """
    Document-level question answering built on a LanguageUnderstanding system.

    Documents are split into sentences, and each sentence is indexed in the
    shared knowledge base (tagged with its doc_id) for fine-grained retrieval.
    """

    def __init__(self, language_system):
        self.lang = language_system
        self.documents = {}

    def add_document(self, doc_id, text):
        """Register a document and index its sentences for retrieval."""
        # Naive sentence split on '.'; empty fragments are dropped.
        sentences = [part.strip() for part in text.split('.') if part.strip()]

        self.documents[doc_id] = {
            'full_text': text,
            'sentences': sentences
        }

        # Index every sentence, remembering which document it came from.
        for sentence in sentences:
            self.lang.add_knowledge(sentence, {'doc_id': doc_id})

    def answer_from_document(self, question, doc_id=None):
        """
        Answer a question from one document (when doc_id is given) or from
        everything indexed so far.

        Returns None for an unknown doc_id; otherwise a list of
        (sentence, score) tuples.
        """
        if doc_id and doc_id not in self.documents:
            return None

        hits = self.lang.answer_question(question, top_k=5)

        if not doc_id:
            return hits

        # Keep only hits whose text belongs to the requested document.
        known = self.documents[doc_id]['sentences']
        return [(sentence, score) for sentence, score in hits if sentence in known]

# =============================================================================
# TESTING
# =============================================================================

def test_language_understanding():
    """
    Test language understanding capabilities end to end.

    Runs four suites against a small fixed knowledge base: question
    answering, semantic similarity, text classification, and reading
    comprehension, then prints a scoreboard.

    Returns:
        bool: True when at least 60% of all checks pass, False otherwise.

    Note: instantiating LanguageUnderstanding downloads/loads the
    Sentence-BERT model, so this requires network or a local model cache.
    """
    print("\n" + "="*70)
    print("TESTING LANGUAGE UNDERSTANDING")
    print("="*70)
    
    print("\nInitializing language system...")
    lang = LanguageUnderstanding()
    
    # Add knowledge
    print("\nAdding knowledge base...")
    knowledge_base = [
        "Python is a high-level programming language known for its simplicity.",
        "Machine learning is a subset of artificial intelligence.",
        "Neural networks are inspired by biological neurons in the brain.",
        "Deep learning uses multiple layers of neural networks.",
        "Transformers are a type of neural network architecture.",
        "GPT stands for Generative Pre-trained Transformer.",
        "The Eiffel Tower is located in Paris, France.",
        "JavaScript is commonly used for web development.",
        "Photosynthesis is how plants convert sunlight into energy.",
        "The human brain contains billions of neurons.",
    ]
    
    for text in knowledge_base:
        lang.add_knowledge(text)
        print(f"  ✓ {text[:60]}...")
    
    print(f"\n✅ Added {len(knowledge_base)} facts to knowledge base")
    
    # Test 1: Question Answering — each question is scored as correct when
    # the top-1 retrieved fact contains the expected keyword.
    print("\n" + "="*70)
    print("TEST 1: QUESTION ANSWERING")
    print("="*70)
    
    # (question, keyword that must appear in the retrieved answer)
    test_questions = [
        ("What programming language is known for simplicity?", "Python"),
        ("What is machine learning?", "artificial intelligence"),
        ("What are neural networks inspired by?", "brain"),
        ("Where is the Eiffel Tower?", "Paris"),
        ("What is photosynthesis?", "plants")
    ]
    
    qa_passed = 0
    
    for question, expected_keyword in test_questions:
        print(f"\nQ: {question}")
        answers = lang.answer_question(question, top_k=1)
        
        if answers:
            answer_text, score = answers[0]
            print(f"A: {answer_text}")
            print(f"   (confidence: {score:.3f})")
            
            # Case-insensitive keyword match decides pass/fail.
            if expected_keyword.lower() in answer_text.lower():
                print("   ✅ Correct!")
                qa_passed += 1
            else:
                print("   ❌ Wrong")
        else:
            print("A: No answer found")
            print("   ❌ Wrong")
    
    # Test 2: Semantic Similarity — similarity above 0.5 counts as "similar";
    # each pair is checked against the expected verdict.
    print("\n" + "="*70)
    print("TEST 2: SEMANTIC SIMILARITY")
    print("="*70)
    
    # (text1, text2, whether the pair should be judged similar)
    similarity_tests = [
        ("The cat sat on the mat", "A feline rested on the rug", True),
        ("I love programming", "Coding is my passion", True),
        ("The weather is sunny", "It's raining heavily", False),
        ("Neural networks learn patterns", "AI systems detect patterns", True),
    ]
    
    sim_passed = 0
    
    for text1, text2, should_be_similar in similarity_tests:
        sim = lang.semantic_similarity(text1, text2)
        is_similar = sim > 0.5  # fixed decision threshold for this test
        
        print(f"\nText 1: {text1}")
        print(f"Text 2: {text2}")
        print(f"Similarity: {sim:.3f}")
        
        if is_similar == should_be_similar:
            print("✅ Correct!")
            sim_passed += 1
        else:
            print("❌ Wrong")
    
    # Test 3: Text Classification — zero-shot against category descriptions;
    # pass when the expected keyword appears in the chosen category.
    print("\n" + "="*70)
    print("TEST 3: TEXT CLASSIFICATION")
    print("="*70)
    
    categories = [
        "technology and computers",
        "nature and environment",
        "food and cooking",
        "sports and fitness"
    ]
    
    # (text to classify, keyword expected in the winning category)
    classification_tests = [
        ("I built a new web application using React", "technology"),
        ("The forest ecosystem is diverse and complex", "nature"),
        ("I made pasta with tomato sauce for dinner", "food"),
        ("Running improves cardiovascular health", "sports")
    ]
    
    class_passed = 0
    
    for text, expected in classification_tests:
        category, confidence = lang.text_classification(text, categories)
        
        print(f"\nText: {text}")
        print(f"Classified as: {category} ({confidence:.3f})")
        
        if expected in category.lower():
            print("✅ Correct!")
            class_passed += 1
        else:
            print("❌ Wrong")
    
    # Test 4: Reading Comprehension — index one document, then answer
    # questions restricted to that document's sentences.
    print("\n" + "="*70)
    print("TEST 4: READING COMPREHENSION")
    print("="*70)
    
    rc = ReadingComprehension(lang)
    
    document = """
    Artificial intelligence has made remarkable progress in recent years.
    Deep learning models can now understand images, text, and speech.
    Transformers revolutionized natural language processing.
    Large language models like GPT can generate human-like text.
    AI is being applied in healthcare, finance, and education.
    """
    
    rc.add_document("ai_doc", document)
    print("Added AI document to comprehension system")
    
    # (question, keyword expected in the retrieved sentence)
    comp_questions = [
        ("What revolutionized NLP?", "Transformers"),
        ("Where is AI being applied?", "healthcare"),
        ("What can deep learning understand?", "images")
    ]
    
    comp_passed = 0
    
    for question, expected in comp_questions:
        print(f"\nQ: {question}")
        answers = rc.answer_from_document(question, "ai_doc")
        
        if answers:
            answer_text, score = answers[0]
            print(f"A: {answer_text}")
            
            if expected.lower() in answer_text.lower():
                print("   ✅ Correct!")
                comp_passed += 1
            else:
                print("   ❌ Wrong")
        else:
            print("   No answer found")
            print("   ❌ Wrong")
    
    # Results — aggregate the four suites into a single pass rate.
    print("\n" + "="*70)
    print("FINAL RESULTS")
    print("="*70)
    
    total_tests = len(test_questions) + len(similarity_tests) + len(classification_tests) + len(comp_questions)
    total_passed = qa_passed + sim_passed + class_passed + comp_passed
    
    print(f"\nQuestion Answering: {qa_passed}/{len(test_questions)}")
    print(f"Semantic Similarity: {sim_passed}/{len(similarity_tests)}")
    print(f"Text Classification: {class_passed}/{len(classification_tests)}")
    print(f"Reading Comprehension: {comp_passed}/{len(comp_questions)}")
    print(f"\nTotal: {total_passed}/{total_tests} ({100*total_passed/total_tests:.1f}%)")
    
    # 80%+ is "excellent", 60%+ still passes; below that the run fails.
    if total_passed >= total_tests * 0.8:
        print("\n✅ EXCELLENT - Language understanding working!")
        return True
    elif total_passed >= total_tests * 0.6:
        print("\n✅ GOOD - Strong language capabilities!")
        return True
    else:
        print("\n⚠️ Needs improvement")
        return False

def main():
    """Run the language-understanding test suite and summarize capabilities."""
    ok = test_language_understanding()

    print("\n" + "="*70)
    # Only print the capability summary on a passing run.
    if not ok:
        return
    print("✅ LANGUAGE UNDERSTANDING: WORKING")
    print("\nCapabilities:")
    print("  1. Question answering from knowledge base")
    print("  2. Semantic similarity detection")
    print("  3. Text classification")
    print("  4. Reading comprehension")
    print("  5. Context understanding")
    print("\n✅ CAPABILITY #8 COMPLETE")
    print("\n📊 8/9 CAPABILITIES DONE - ONE TO GO!")

# Script entry point: run the self-test when executed directly.
if __name__ == "__main__":
    main()
