#!/usr/bin/env python3
"""
Real Causal Reasoning with Graph Neural Networks
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv
from torch_geometric.data import Data
import numpy as np
import random
from tqdm import tqdm
import networkx as nx

# =============================================================================
# CAUSAL GRAPH
# =============================================================================

class CausalGraph:
    """Directed causal graph over named variables.

    Nodes are identified by name and receive consecutive integer indices in
    insertion order. Edges are stored as ``(cause_idx, effect_idx)`` pairs.

    Attributes:
        nodes: Node names; a name's position in the list is its index.
        edges: Unique directed edges as ``(cause_idx, effect_idx)`` tuples.
        node_to_idx: Mapping from node name to its index in ``nodes``.
        confounders: Unique indices of nodes registered via
            ``add_confounder``.
    """

    def __init__(self):
        self.nodes = []
        self.edges = []
        self.node_to_idx = {}
        self.confounders = []

    def add_node(self, name):
        """Register *name* as a node; an already-known name is left as is."""
        if name not in self.node_to_idx:
            self.node_to_idx[name] = len(self.nodes)
            self.nodes.append(name)

    def add_edge(self, cause, effect):
        """Add the directed edge ``cause -> effect``, creating nodes as needed.

        Re-adding an existing edge is a no-op so that a repeated declaration
        does not show up twice in the edge index.
        """
        self.add_node(cause)
        self.add_node(effect)
        edge = (self.node_to_idx[cause], self.node_to_idx[effect])
        if edge not in self.edges:
            self.edges.append(edge)

    def add_confounder(self, confounder, var1, var2):
        """Record *confounder* as a common cause of *var1* and *var2*.

        Adds the edges ``confounder -> var1`` and ``confounder -> var2`` and
        remembers the confounder's index (once, even if it confounds several
        pairs).
        """
        # add_edge registers the confounder node first, so it keeps the
        # lowest index of the three when all of them are new.
        self.add_edge(confounder, var1)
        self.add_edge(confounder, var2)
        confounder_idx = self.node_to_idx[confounder]
        if confounder_idx not in self.confounders:
            self.confounders.append(confounder_idx)

    def to_edge_index(self):
        """Return the edges as a ``(2, num_edges)`` int64 tensor.

        Row 0 holds cause indices and row 1 holds effect indices. A graph
        without edges yields an empty ``(2, 0)`` tensor.
        """
        if not self.edges:
            return torch.zeros((2, 0), dtype=torch.long)
        # .t() only swaps strides; make the result contiguous so downstream
        # graph ops receive a plain row-major tensor.
        return torch.tensor(self.edges, dtype=torch.long).t().contiguous()

def _causation_scenario(graph, x, y, answer, explanation):
    """Build one "does X cause Y?" scenario record for *graph*."""
    return {
        'graph': graph,
        'question_type': 'causation',
        'X': x,
        'Y': y,
        'answer': answer,
        'explanation': explanation,
    }


def generate_causal_graph_scenarios():
    """Return the hand-written causal scenarios used as the dataset.

    Each scenario is a dict with the keys ``graph`` (a ``CausalGraph``),
    ``question_type`` (always ``'causation'``), ``X`` and ``Y`` (names of the
    two nodes the question is about), ``answer`` (``'yes'`` or ``'no'``) and
    ``explanation`` (a short human-readable reason).

    Three structural families are produced:

    * common cause: X and Y share a parent and have no edge between them
      (answer ``'no'``);
    * direct edge: X -> Y (answer ``'yes'``);
    * chain a -> b -> c, asked once about (a, b) and once about (a, c)
      (answer ``'yes'`` for both). Both chain scenarios share one graph
      object.
    """
    scenarios = []

    # Temperature drives both ice cream sales and drowning.
    ice_cream_graph = CausalGraph()
    ice_cream_graph.add_confounder("temperature", "ice_cream", "drowning")
    scenarios.append(_causation_scenario(
        ice_cream_graph, 'ice_cream', 'drowning', 'no',
        'Temperature confounds'))

    # Studying directly affects grades.
    studying_graph = CausalGraph()
    studying_graph.add_edge("studying", "grades")
    scenarios.append(_causation_scenario(
        studying_graph, 'studying', 'grades', 'yes',
        'Direct causation'))

    # Smoking causes both yellow fingers and lung cancer.
    smoking_graph = CausalGraph()
    smoking_graph.add_edge("smoking", "yellow_fingers")
    smoking_graph.add_edge("smoking", "lung_cancer")
    scenarios.append(_causation_scenario(
        smoking_graph, 'yellow_fingers', 'lung_cancer', 'no',
        'Common cause smoking'))

    # Fire size drives both the number of firefighters and the damage.
    fire_graph = CausalGraph()
    fire_graph.add_edge("fire_size", "firefighters")
    fire_graph.add_edge("fire_size", "damage")
    scenarios.append(_causation_scenario(
        fire_graph, 'firefighters', 'damage', 'no',
        'Common cause fire'))

    # Age confounds shoe size and reading ability.
    age_graph = CausalGraph()
    age_graph.add_confounder("age", "shoe_size", "reading")
    scenarios.append(_causation_scenario(
        age_graph, 'shoe_size', 'reading', 'no',
        'Age confounds'))

    # Further (confounder, var1, var2) triples with the same structure.
    confounded = [
        ("weather", "ice_cream", "sunburn"),
        ("wealth", "education", "health"),
        ("city_size", "crime", "traffic"),
        ("season", "heating", "flu"),
    ]
    for conf, v1, v2 in confounded:
        graph = CausalGraph()
        graph.add_confounder(conf, v1, v2)
        scenarios.append(_causation_scenario(
            graph, v1, v2, 'no', f'{conf} confounds'))

    # Causal chains a -> b -> c.
    chains = [
        ("rain", "wet", "slip"),
        ("sleep", "energy", "performance"),
        ("practice", "skill", "success"),
        ("diet", "weight", "health"),
    ]
    for a, b, c in chains:
        graph = CausalGraph()
        graph.add_edge(a, b)
        graph.add_edge(b, c)

        scenarios.append(_causation_scenario(
            graph, a, b, 'yes', 'Direct edge'))
        # Name the actual intermediate variable rather than the letter "b".
        scenarios.append(_causation_scenario(
            graph, a, c, 'yes', f'Path through {b}'))

    return scenarios

# =============================================================================
# GNN MODEL
# =============================================================================

class CausalGNN(nn.Module):
    """Graph attention classifier for "does X cause Y?" on a single graph.

    Each node index is looked up in a learned embedding table, passed through
    three ``GATConv`` layers, and the final representations of the two
    question nodes are concatenated and fed to an MLP that emits two logits.
    Elsewhere in this file, class 1 is used for 'yes' and class 0 for 'no'.

    Args:
        node_features: Width of the learned per-node embedding.
        hidden_dim: Per-head output width of every GAT layer.
        max_nodes: Number of rows in the embedding table, i.e. the largest
            node index + 1 the model can embed. Defaults to 100, the value
            that was previously hard-coded, so existing checkpoints still
            load with the default.

    NOTE(review): ``GATConv`` is used with its library defaults apart from
    ``heads``/``concat``; confirm those defaults match the intended
    cause -> effect semantics.
    """

    def __init__(self, node_features=32, hidden_dim=64, max_nodes=100):
        super().__init__()

        self.node_embedding = nn.Embedding(max_nodes, node_features)

        # The first two layers concatenate 4 heads, hence the hidden_dim * 4
        # input width of the following layer; the last layer uses one head
        # without concatenation and outputs hidden_dim features.
        self.conv1 = GATConv(node_features, hidden_dim, heads=4, concat=True)
        self.conv2 = GATConv(hidden_dim * 4, hidden_dim, heads=4, concat=True)
        self.conv3 = GATConv(hidden_dim * 4, hidden_dim, heads=1, concat=False)

        # Input is the concatenation of the two question-node vectors.
        self.combiner = nn.Sequential(
            nn.Linear(hidden_dim * 2, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 64),
            nn.ReLU()
        )

        self.classifier = nn.Linear(64, 2)

    def forward(self, node_indices, edge_index, question_nodes):
        """Score one (X, Y) question on one graph.

        Args:
            node_indices: Integer tensor of node indices to embed; every
                value must be below ``max_nodes``.
            edge_index: Edge tensor handed unchanged to each ``GATConv``.
            question_nodes: Indexable pair; element 0 is the row of X and
                element 1 the row of Y in the node representation matrix.

        Returns:
            The classifier output for the concatenated X/Y representations
            (two logits for a single, unbatched graph).
        """
        x = self.node_embedding(node_indices)

        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = self.conv3(x, edge_index)

        x_node = x[question_nodes[0]]
        y_node = x[question_nodes[1]]

        combined = torch.cat([x_node, y_node])
        features = self.combiner(combined)
        logits = self.classifier(features)

        return logits

def scenario_to_graph_data(scenario):
    """Pack a scenario dict into a ``Data`` object for the model.

    Reads ``scenario['graph']``, ``scenario['X']``, ``scenario['Y']`` and
    ``scenario['answer']``. The resulting object carries:

    * ``x``: node indices ``0 .. num_nodes - 1``;
    * ``edge_index``: the graph's edge tensor;
    * ``question_nodes``: tensor ``[index of X, index of Y]``;
    * ``y``: one-element tensor, 1 when the answer is ``'yes'``, else 0;
    * ``num_nodes``: number of nodes in the graph.

    Raises ``KeyError`` if X or Y is not a node of the graph.
    """
    causal_graph = scenario['graph']
    name_to_index = causal_graph.node_to_idx
    node_count = len(causal_graph.nodes)

    question_pair = [
        name_to_index[scenario['X']],
        name_to_index[scenario['Y']],
    ]

    # Anything other than the exact string 'yes' maps to class 0.
    target = int(scenario['answer'] == 'yes')

    return Data(
        x=torch.arange(node_count),
        edge_index=causal_graph.to_edge_index(),
        question_nodes=torch.tensor(question_pair),
        y=torch.tensor([target]),
        num_nodes=node_count,
    )

# =============================================================================
# TRAINING
# =============================================================================

def train_causal_gnn(epochs=200):
    """Train a ``CausalGNN`` on the built-in scenarios.

    The scenario list is replicated 20 times, shuffled and split 80/20 into
    train and test parts. Training uses Adam (lr=0.001) with one graph per
    optimisation step and cross-entropy loss. After every epoch the model is
    evaluated on the test part; whenever test accuracy strictly improves
    (and always after the first epoch) the weights are written to
    ``causal_gnn.pth`` in the current working directory.

    Args:
        epochs: Number of passes over the training part.

    Returns:
        The model as it stands after the final epoch. This may differ from
        the checkpointed best-epoch weights.

    NOTE(review): because scenarios are duplicated before the split, the
    test part contains copies of training samples, and the checkpoint is
    selected on that same test part. The reported accuracy therefore does
    not measure generalisation.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Device: {device}\n")

    print("Generating scenarios...")
    scenarios = generate_causal_graph_scenarios()
    scenarios = scenarios * 20
    random.shuffle(scenarios)

    print(f"Total: {len(scenarios)}\n")

    # Move every sample to the target device once instead of on each visit.
    graph_data = [scenario_to_graph_data(s).to(device) for s in scenarios]

    split = int(0.8 * len(graph_data))
    train_data = graph_data[:split]
    test_data = graph_data[split:]

    print(f"Train: {len(train_data)}, Test: {len(test_data)}\n")

    model = CausalGNN().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    print(f"Training {epochs} epochs...\n")

    best_acc = 0

    for epoch in range(epochs):
        model.train()
        train_correct = 0

        random.shuffle(train_data)

        for data in tqdm(train_data, desc=f"Epoch {epoch+1}"):
            optimizer.zero_grad()
            logits = model(data.x, data.edge_index, data.question_nodes)
            # The model returns unbatched logits; add a batch axis so they
            # line up with the one-element label tensor.
            loss = F.cross_entropy(logits.unsqueeze(0), data.y)
            loss.backward()
            optimizer.step()

            pred = logits.argmax().item()
            train_correct += int(pred == data.y.item())

        train_acc = 100 * train_correct / len(train_data)

        # Test
        model.eval()
        test_correct = 0

        with torch.no_grad():
            for data in test_data:
                logits = model(data.x, data.edge_index, data.question_nodes)
                pred = logits.argmax().item()
                test_correct += int(pred == data.y.item())

        test_acc = 100 * test_correct / len(test_data)

        if (epoch + 1) % 20 == 0:
            print(f"Epoch {epoch+1}: Train: {train_acc:.1f}%, Test: {test_acc:.1f}%")

        # Always checkpoint the first epoch: with a strict ">" against an
        # initial 0, a run whose test accuracy never leaves 0% would write
        # no file at all and the later load would fail.
        if epoch == 0 or test_acc > best_acc:
            best_acc = test_acc
            torch.save(model.state_dict(), 'causal_gnn.pth')

    print(f"\n✅ Best: {best_acc:.1f}%")
    return model

# =============================================================================
# TESTING
# =============================================================================

def test_causal_gnn():
    """Load the saved checkpoint and evaluate it on four fixed questions.

    Reads ``causal_gnn.pth`` from the current working directory, builds four
    small graphs (one confounded pair, one direct edge, two common-cause
    pairs), prints the prediction for each and a final score.

    Returns:
        ``True`` when at least 3 of the 4 questions are answered as
        expected, otherwise ``False``.

    Raises:
        FileNotFoundError: If the checkpoint file does not exist.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    print("\n" + "="*70)
    print("TESTING CAUSAL GNN")
    print("="*70)

    model = CausalGNN().to(device)
    # map_location lets a checkpoint saved on a GPU machine be loaded on a
    # CPU-only one (and vice versa).
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    model.load_state_dict(torch.load('causal_gnn.pth', map_location=device))
    model.eval()

    # Test cases
    test_cases = []

    # Test 1
    g1 = CausalGraph()
    g1.add_confounder("temperature", "ice_cream", "drowning")
    test_cases.append({
        'name': 'Ice Cream & Drowning (Confounded)',
        'graph': g1,
        'X': 'ice_cream',
        'Y': 'drowning',
        'expected': 'no'
    })

    # Test 2
    g2 = CausalGraph()
    g2.add_edge("studying", "grades")
    test_cases.append({
        'name': 'Studying -> Grades (Direct)',
        'graph': g2,
        'X': 'studying',
        'Y': 'grades',
        'expected': 'yes'
    })

    # Test 3
    g3 = CausalGraph()
    g3.add_edge("fire", "firefighters")
    g3.add_edge("fire", "damage")
    test_cases.append({
        'name': 'Firefighters & Damage (Common Cause)',
        'graph': g3,
        'X': 'firefighters',
        'Y': 'damage',
        'expected': 'no'
    })

    # Test 4
    g4 = CausalGraph()
    g4.add_edge("smoking", "yellow_fingers")
    g4.add_edge("smoking", "cancer")
    test_cases.append({
        'name': 'Yellow Fingers & Cancer (Common Cause)',
        'graph': g4,
        'X': 'yellow_fingers',
        'Y': 'cancer',
        'expected': 'no'
    })

    passed = 0

    for i, test in enumerate(test_cases, 1):
        print(f"\n{'='*70}")
        print(f"TEST {i}: {test['name']}")
        print(f"{'='*70}")

        # scenario_to_graph_data expects the label under 'answer'.
        scenario = {
            'graph': test['graph'],
            'X': test['X'],
            'Y': test['Y'],
            'answer': test['expected']
        }

        data = scenario_to_graph_data(scenario).to(device)

        with torch.no_grad():
            logits = model(data.x, data.edge_index, data.question_nodes)
            pred = logits.argmax().item()

        prediction = 'yes' if pred == 1 else 'no'

        print(f"Question: Does {test['X']} cause {test['Y']}?")
        print(f"Expected: {test['expected']}")
        print(f"Predicted: {prediction}")

        if prediction == test['expected']:
            print("✅ CORRECT!")
            passed += 1
        else:
            print("❌ WRONG")

    print(f"\n{'='*70}")
    print(f"SCORE: {passed}/{len(test_cases)} ({100*passed//len(test_cases)}%)")
    print(f"{'='*70}")

    return passed >= 3

def main():
    """Train the model, run the fixed checks and print a summary."""
    # No separate torch_geometric availability check here: the module-level
    # imports already require it, so a missing install fails at import time
    # before main() can run.
    train_causal_gnn(epochs=200)
    success = test_causal_gnn()

    print("\n" + "="*70)
    print("FINAL RESULTS")
    print("="*70)

    if success:
        print("\n✅ SUCCESS! Causal GNN working!")
        print("\nCapabilities:")
        print("  1. Distinguish correlation from causation")
        print("  2. Identify confounders")
        print("  3. Reason about causal graphs")
    else:
        print("\n⚠️ Needs more training")

if __name__ == "__main__":
    main()
