#!/usr/bin/env python3
"""
CONTINUAL LEARNING
Learn multiple tasks sequentially without catastrophic forgetting
Using Elastic Weight Consolidation (EWC)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from copy import deepcopy

# Fix RNG seeds for both torch and numpy so runs are reproducible.
torch.manual_seed(42)
np.random.seed(42)

# Run on GPU when available; all tensors and the model below are moved here.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}\n")

class ContinualNet(nn.Module):
    """Small MLP used for every task: a shared trunk plus one 5-way head.

    20-dim input -> 128 -> 64 (ReLU trunk) -> 5 logits. All tasks share the
    same parameters; nothing is task-specific.
    """

    def __init__(self):
        super().__init__()
        # Modules are created in the same order as before so seeded weight
        # initialization (and parameter names 'shared.0', 'shared.2') match.
        trunk = [
            nn.Linear(20, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
        ]
        self.shared = nn.Sequential(*trunk)
        self.head = nn.Linear(64, 5)

    def forward(self, x):
        """Return 5-class logits for a batch of 20-dim inputs."""
        return self.head(self.shared(x))

def create_task(task_id, n_samples=100, to_device=None):
    """Generate a synthetic 5-class task whose label rule depends on task_id.

    Each task computes a binary statistic of the 20-dim input, then shifts it
    by task_id so different tasks occupy different label regions.

    Args:
        task_id: which of the 5 label rules to use (0-4; anything else
            falls through to rule 4, as before).
        n_samples: number of examples to draw (default 100, the old
            hard-coded size).
        to_device: optional device override; when None the module-level
            ``device`` is used, so existing callers are unchanged.

    Returns:
        (X, Y): float32 features of shape (n_samples, 20) and int64 labels
        in [0, 5), both on the chosen device.
    """
    X = []
    Y = []

    for _ in range(n_samples):
        x = np.random.randn(20)

        # Different task = different pattern
        if task_id == 0:
            # Task 0: Sum of first 5 features
            y = int(x[:5].sum() > 0)
        elif task_id == 1:
            # Task 1: Product of middle features
            y = int(x[7:12].prod() > 0)
        elif task_id == 2:
            # Task 2: Max of last 5 features
            y = int(x[-5:].max() > 0)
        elif task_id == 3:
            # Task 3: Variance of features
            y = int(x.var() > 1.0)
        else:
            # Task 4: Mean absolute value
            y = int(np.abs(x).mean() > 0.8)

        # Make it 5-class by modulating
        y = (y + task_id) % 5

        X.append(x)
        Y.append(y)

    # np.stack + as_tensor avoids the slow (and warning-prone) conversion of
    # a Python list of ndarrays through torch.FloatTensor.
    dev = device if to_device is None else torch.device(to_device)
    features = torch.as_tensor(np.stack(X), dtype=torch.float32)
    labels = torch.as_tensor(Y, dtype=torch.long)
    return features.to(dev), labels.to(dev)

def compute_fisher(model, task_data):
    """Estimate the diagonal empirical Fisher information for EWC.

    For every example, backprops the cross-entropy of the model's prediction
    against the true label and accumulates the squared per-parameter
    gradients; the average over the dataset is the Fisher diagonal used to
    weight the EWC penalty.

    Args:
        model: any nn.Module producing class logits.
        task_data: (X, Y) tensors for one task.

    Returns:
        dict mapping parameter name -> tensor of averaged squared gradients
        (same shape as the parameter).
    """
    X, Y = task_data
    model.eval()

    fisher = {name: torch.zeros_like(p) for name, p in model.named_parameters()}

    n_examples = len(X)
    # One example per backward pass: squared per-sample gradients, not the
    # squared mean gradient.
    for idx in range(n_examples):
        model.zero_grad()
        logits = model(X[idx:idx + 1])
        F.cross_entropy(logits, Y[idx:idx + 1]).backward()
        for name, p in model.named_parameters():
            if p.grad is not None:
                fisher[name] += p.grad.data ** 2

    for name in fisher:
        fisher[name] /= n_examples

    return fisher

def ewc_loss(model, old_params, fisher, lambda_ewc=1000):
    """Quadratic EWC penalty anchoring weights to a previous solution.

    Sums fisher-weighted squared deviations of each current parameter from
    its stored value; parameters absent from old_params are skipped.

    Args:
        model: the network being trained.
        old_params: name -> parameter snapshot from the previous task.
        fisher: name -> Fisher diagonal weighting each deviation.
        lambda_ewc: penalty strength (default 1000).

    Returns:
        lambda_ewc times the total weighted squared deviation.
    """
    penalty = sum(
        (fisher[name] * (p - old_params[name]) ** 2).sum()
        for name, p in model.named_parameters()
        if name in old_params
    )
    return lambda_ewc * penalty

print("="*70)
print("CONTINUAL LEARNING TEST")
print("="*70)

# A single shared network and optimizer reused across all tasks — the point
# of continual learning is that there are no per-task models.
model = ContinualNet().to(device)
opt = torch.optim.Adam(model.parameters(), lr=0.001)

# Storage for old tasks
# old_params / fisher_info hold the previous task's weight snapshot and
# Fisher diagonal for the EWC penalty; task_accuracies collects one row of
# per-task accuracies after each task is learned.
old_params = None
fisher_info = None
task_accuracies = []

# Learn 5 tasks sequentially. From the second task on, an EWC penalty pulls
# the weights back toward values that mattered for the previous task.
for task_id in range(5):
    print(f"\n{'='*70}")
    print(f"LEARNING TASK {task_id}")
    print(f"{'='*70}")
    
    # Create task data
    X, Y = create_task(task_id)
    
    # Train on this task (full-batch gradient steps)
    for epoch in range(100):
        pred = model(X)
        loss = F.cross_entropy(pred, Y)
        
        # Add EWC penalty if we have previous tasks
        if old_params is not None and fisher_info is not None:
            loss += ewc_loss(model, old_params, fisher_info, lambda_ewc=1000)
        
        opt.zero_grad()
        loss.backward()
        opt.step()
        
        if epoch % 25 == 0:
            acc = (pred.argmax(1) == Y).float().mean()
            print(f"  Epoch {epoch}: Loss={loss.item():.3f}, Acc={acc.item()*100:.1f}%")
    
    # Compute Fisher for this task.
    # NOTE(review): fisher_info is overwritten each task, so only the most
    # recent task's Fisher anchors the penalty — a common simplification of
    # full EWC (which keeps one penalty per past task).
    fisher_info = compute_fisher(model, (X, Y))
    
    # Snapshot the learned weights as the EWC anchor.
    # BUG FIX: the snapshot must be detached. A bare param.clone() stays in
    # the autograd graph, so when ewc_loss later backprops through
    # (param - old)**2 the direct path and the path through the clone cancel
    # exactly — the EWC penalty silently contributed zero gradient.
    old_params = {name: param.detach().clone()
                  for name, param in model.named_parameters()}
    
    print(f"\n✅ Task {task_id} learned!")
    
    # Test on ALL previous tasks (fresh samples drawn from each task's rule)
    print(f"\nTesting on all tasks learned so far:")
    task_accs = []
    for test_id in range(task_id + 1):
        X_test, Y_test = create_task(test_id)
        with torch.no_grad():
            pred = model(X_test)
            acc = (pred.argmax(1) == Y_test).float().mean().item()
            task_accs.append(acc)
            print(f"  Task {test_id}: {acc*100:.1f}%")
    
    task_accuracies.append(task_accs)
# Final evaluation
print("\n" + "="*70)
print("FINAL CONTINUAL LEARNING EVALUATION")
print("="*70)

# Accuracy matrix: row i holds accuracy on tasks 0..i measured right after
# task i was learned; reading down a column shows forgetting over time.
print("\nAccuracy matrix (rows=after learning task N, cols=task performance):")
print("     ", end="")
for i in range(5):
    print(f"T{i}    ", end="")
print()

for i, accs in enumerate(task_accuracies):
    print(f"After T{i}: ", end="")
    for acc in accs:
        print(f"{acc*100:5.1f}% ", end="")
    print()

# Check for catastrophic forgetting
print("\n" + "="*70)
print("FORGETTING ANALYSIS")
print("="*70)

# BUG FIX: the old expression np.mean([accs[-1] for accs in
# task_accuracies[1:]]) averaged each task's accuracy immediately after it
# was learned (the matrix diagonal, minus task 0) — that measures
# plasticity, not retention. "After learning all 5" is the FINAL row, which
# holds accuracy on every task after the last one was trained.
avg_final_acc = np.mean(task_accuracies[-1])
print(f"Average accuracy on all tasks after learning all 5: {avg_final_acc*100:.1f}%")

if avg_final_acc >= 0.75:
    print("✅ EXCELLENT - Minimal catastrophic forgetting!")
elif avg_final_acc >= 0.60:
    print("✅ GOOD - Reasonable retention!")
elif avg_final_acc >= 0.45:
    print("⚠️ MODERATE - Some forgetting occurs")
else:
    print("❌ POOR - Significant catastrophic forgetting")