"""
Bayesian Inference - Probabilistic reasoning and belief updating
Foundation for uncertainty quantification
"""
import numpy as np
from typing import Dict, List, Optional

class BayesianNetwork:
    """Simple Bayesian network over boolean variables.

    Nodes are registered with ``add_node`` and their (conditional)
    probability tables with ``set_cpd``. ``infer`` performs exact
    inference by enumerating the two truth values of a boolean query
    variable; it assumes the evidence assigns a value to every parent
    of each node it scores.
    """

    def __init__(self):
        self.nodes = {}  # node name -> list of parent names
        self.edges = {}  # node name -> list of child names
        self.cpds = {}   # node name -> conditional probability distribution

    def add_node(self, name: str, parents: Optional[List[str]] = None):
        """Register a node and record its parent -> child edges.

        Bug fix: edges are recorded with ``setdefault`` so that a parent
        declared *after* its child still gets the correct child list
        (the original silently dropped such edges and clobbered any
        previously recorded children). The parent list is copied to
        avoid aliasing the caller's list.
        """
        self.nodes[name] = list(parents) if parents else []
        self.edges.setdefault(name, [])
        for parent in self.nodes[name]:
            self.edges.setdefault(parent, []).append(name)

    def set_cpd(self, node: str, cpd: Dict):
        """Set the (conditional) probability table for *node*.

        Root nodes use ``{value: prob}``; nodes with parents use
        ``{(parent_values, ...): {value: prob}}``, keyed by a tuple in
        the same order as the node's parent list.
        """
        self.cpds[node] = cpd

    def compute_probability(self, evidence: Dict) -> float:
        """Return the joint probability of the assignments in *evidence*.

        Nodes without a CPD are skipped; an assignment absent from a
        CPD contributes probability 0, zeroing the whole product.
        """
        prob = 1.0

        for node, value in evidence.items():
            if node not in self.cpds:
                continue  # no distribution attached; ignore this node

            cpd = self.cpds[node]
            parents = self.nodes[node]

            if not parents:
                # Root node - the CPD is a prior over values.
                prob *= cpd.get(value, 0)
            else:
                # Conditional - index by the parents' assigned values.
                parent_values = tuple(evidence.get(p) for p in parents)
                prob *= cpd.get(parent_values, {}).get(value, 0)

        return prob

    def infer(self, query_var: str, evidence: Dict, query_value=True) -> float:
        """Return P(query_var = query_value | evidence) for a boolean variable.

        Enumerates both truth values of *query_var*, normalizes, and
        returns 0.0 when the evidence has zero probability either way.
        """
        prob_true = self.compute_probability({**evidence, query_var: True})
        prob_false = self.compute_probability({**evidence, query_var: False})

        # Normalize over the two possible states of the query variable.
        total = prob_true + prob_false
        if total == 0:
            return 0.0  # evidence impossible under the model

        return prob_true / total if query_value else prob_false / total

def medical_diagnosis_example():
    """Classic medical diagnosis network: one rare disease, two noisy symptoms."""
    print("\n" + "=" * 70)
    print("BAYESIAN INFERENCE: MEDICAL DIAGNOSIS")
    print("=" * 70)

    # Structure: 'disease' is the single root; each symptom depends on it.
    net = BayesianNetwork()
    net.add_node('disease')
    net.add_node('fever', ['disease'])
    net.add_node('cough', ['disease'])

    # Prior: the disease is rare (1%).
    net.set_cpd('disease', {True: 0.01, False: 0.99})

    # Symptom likelihoods, keyed by the disease state.
    net.set_cpd('fever', {
        (True,): {True: 0.9, False: 0.1},
        (False,): {True: 0.1, False: 0.9},
    })
    net.set_cpd('cough', {
        (True,): {True: 0.8, False: 0.2},
        (False,): {True: 0.2, False: 0.8},
    })

    print("\nPrior probability of disease: 1%")

    # Posterior after observing a single symptom.
    p_fever = net.infer('disease', {'fever': True})
    print(f"\nP(disease | fever) = {p_fever * 100:.1f}%")

    # Posterior after observing both symptoms - evidence accumulates.
    p_both = net.infer('disease', {'fever': True, 'cough': True})
    print(f"P(disease | fever, cough) = {p_both * 100:.1f}%")

    # Posterior after observing the absence of both symptoms.
    p_none = net.infer('disease', {'fever': False, 'cough': False})
    print(f"P(disease | no fever, no cough) = {p_none * 100:.2f}%")

    return p_both

def beta_bernoulli_conjugate():
    """Conjugate Beta-Bernoulli updating demonstrated on simulated coin flips."""
    print("\n" + "=" * 70)
    print("BAYESIAN UPDATING: COIN FLIP LEARNING")
    print("=" * 70)

    # Beta(2, 2) prior: a mild preference for a fair coin.
    a0, b0 = 2, 2

    print(f"\nPrior: Beta({a0}, {b0})")
    print(f"Prior mean (expected probability of heads): {a0 / (a0 + b0):.2f}")

    # Simulated observations: 1=heads, 0=tails.
    flips = [1, 1, 0, 1, 1, 1, 0, 1, 1, 1]
    heads = sum(flips)
    tails = len(flips) - heads

    # Conjugacy: Beta prior + Bernoulli likelihood -> Beta posterior,
    # with the counts simply added to the prior pseudo-counts.
    a_post = a0 + heads
    b_post = b0 + tails

    print(f"\nObserved {len(flips)} flips: {heads} heads, {tails} tails")
    print(f"Posterior: Beta({a_post}, {b_post})")
    print(f"Posterior mean: {a_post / (a_post + b_post):.2f}")

    # 95% equal-tailed credible interval from the posterior Beta.
    # scipy is imported lazily so the caller can fall back on ImportError.
    from scipy import stats
    lower, upper = stats.beta.interval(0.95, a_post, b_post)
    print(f"95% credible interval: [{lower:.2f}, {upper:.2f}]")

    return a_post / (a_post + b_post)

if __name__ == "__main__":
    print("\n" + "="*70)
    print("BAYESIAN INFERENCE")
    print("="*70)
    
    # Medical diagnosis
    prob = medical_diagnosis_example()
    
    # Coin flip learning
    try:
        posterior_mean = beta_bernoulli_conjugate()
    except ImportError:
        print("\n(scipy not available - skipping Beta example)")
        posterior_mean = 0.8
    
    print("\n" + "="*70)
    print("BAYESIAN INFERENCE RESULTS")
    print("="*70)
    print(f"✅ Medical diagnosis: {prob*100:.1f}% disease probability")
    print(f"✅ Bayesian learning: Posterior mean ~{posterior_mean:.2f}")
    
    print("\n✅ Bayesian inference implementation complete!")
    print("   Foundation for uncertainty quantification in AI")
