#!/usr/bin/env python3
"""
CONTINUAL LEARNING - MULTI-HEAD
Each task gets its own output head, shared backbone
This is the most practical approach for continual learning
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

# Reproducibility: fix the RNG seeds for both torch and numpy.
torch.manual_seed(42)
np.random.seed(42)

# BUG FIX: 'cuda' was hard-coded, which crashes at the first .to(device)
# call on CPU-only machines; fall back to CPU when no GPU is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}\n")

class MultiHeadNet(nn.Module):
    """Shared-backbone network with one classification head per task.

    All tasks share a single feature extractor; each task owns its own
    linear output head, so training a new task never overwrites another
    task's decision layer.
    """

    def __init__(self, n_tasks=5):
        super().__init__()
        # Shared feature extractor: 20 -> 256 -> 128, ReLU + dropout.
        extractor_layers = [
            nn.Linear(20, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.1),
        ]
        self.backbone = nn.Sequential(*extractor_layers)

        # One independent 5-way classifier head per task.
        self.heads = nn.ModuleList(nn.Linear(128, 5) for _ in range(n_tasks))

    def forward(self, x, task_id):
        """Run the shared backbone, then the head selected by task_id."""
        shared_features = self.backbone(x)
        task_head = self.heads[task_id]
        return task_head(shared_features)

def create_task(task_id, n_samples=200):
    """Generate a synthetic classification dataset for one task.

    Each task id applies a different labelling rule over 20-dim Gaussian
    inputs, so the tasks are genuinely distinct. All labels fall in
    range(5), matching the 5-way task heads.

    Args:
        task_id: which rule to apply (0-3; any other id uses the last rule).
        n_samples: number of (x, y) pairs to generate.

    Returns:
        (X, Y): float tensor of shape (n_samples, 20) and long tensor of
        shape (n_samples,), both moved to the module-level `device`.
    """
    # Draw all inputs at once; randn fills row-major, so the random
    # stream is identical to calling randn(20) n_samples times.
    X = np.random.randn(n_samples, 20)
    Y = np.empty(n_samples, dtype=np.int64)

    for i, x in enumerate(X):
        if task_id == 0:
            Y[i] = 0 if x[:10].sum() > 0 else 1
        elif task_id == 1:
            Y[i] = 2 if x[5:15].mean() > 0 else 3
        elif task_id == 2:
            Y[i] = 4 if x[-10:].std() > 1.0 else 0
        elif task_id == 3:
            Y[i] = 1 if x[::2].sum() > 0 else 2
        else:
            Y[i] = 3 if np.abs(x).max() > 1.5 else 4

    # PERF/BUG FIX: torch.FloatTensor(list_of_ndarrays) takes a slow,
    # deprecated conversion path; from_numpy on contiguous arrays is a
    # zero-copy handoff before the device transfer.
    return (torch.from_numpy(X).float().to(device),
            torch.from_numpy(Y).to(device))

print("="*70)
print("CONTINUAL LEARNING - MULTI-HEAD ARCHITECTURE")
print("="*70)

model = MultiHeadNet(n_tasks=5).to(device)
opt = torch.optim.Adam(model.parameters(), lr=0.001)

# Replay buffer: keep the full dataset of every task seen so far.
all_tasks = []
task_accuracies = []

# Learn 5 tasks sequentially, replaying all earlier tasks at every step.
for task_id in range(5):
    print(f"\n{'='*70}")
    print(f"LEARNING TASK {task_id}")
    print(f"{'='*70}")

    # Get task data
    X, Y = create_task(task_id, n_samples=200)
    all_tasks.append((X, Y, task_id))

    # Train on ALL tasks seen so far (joint replay, full batch).
    for epoch in range(100):
        model.train()  # ensure dropout is active while optimizing
        total_loss = 0

        # Sum the cross-entropy loss over every task seen so far.
        for X_t, Y_t, tid in all_tasks:
            pred = model(X_t, tid)
            total_loss = total_loss + F.cross_entropy(pred, Y_t)

        opt.zero_grad()
        total_loss.backward()
        opt.step()

        if epoch % 25 == 0:
            # BUG FIX: switch to eval mode before measuring accuracy —
            # the model has Dropout layers, and leaving them active made
            # the reported accuracy noisy and pessimistic.
            model.eval()
            with torch.no_grad():
                pred = model(X, task_id)
                acc = (pred.argmax(1) == Y).float().mean()
                print(f"  Epoch {epoch}: Task {task_id} acc={acc.item()*100:.1f}%")

    # Test on ALL tasks learned so far, on freshly sampled data.
    print(f"\nTesting all tasks:")
    model.eval()  # dropout off for deterministic evaluation
    task_accs = []
    for tid in range(task_id + 1):
        X_test, Y_test = create_task(tid, n_samples=200)
        with torch.no_grad():
            pred = model(X_test, tid)
            acc = (pred.argmax(1) == Y_test).float().mean().item()
            task_accs.append(acc)
            status = "✅" if acc >= 0.80 else "⚠️" if acc >= 0.70 else "❌"
            print(f"  {status} Task {tid}: {acc*100:.1f}%")

    task_accuracies.append(task_accs)

# Final evaluation: summarize accuracy across all 5 tasks.
print("\n" + "="*70)
print("FINAL CONTINUAL LEARNING RESULTS")
print("="*70)

final_accs = task_accuracies[-1]
avg_acc = np.mean(final_accs)
min_acc = np.min(final_accs)

print(f"\nFinal accuracy on all 5 tasks:")
for tid, acc in enumerate(final_accs):
    print(f"  Task {tid}: {acc*100:.1f}%")

print(f"\nAverage: {avg_acc*100:.1f}%")
print(f"Minimum: {min_acc*100:.1f}%")

# Map the average accuracy onto a qualitative verdict; the min-accuracy
# requirement only applies to the top tier.
if avg_acc >= 0.85 and min_acc >= 0.75:
    verdict = "✅ EXCELLENT - Continual learning WORKS!"
elif avg_acc >= 0.75:
    verdict = "✅ GOOD - Effective continual learning!"
elif avg_acc >= 0.65:
    verdict = "⚠️ MODERATE - Some degradation"
else:
    verdict = "❌ Still struggling"
print(verdict)

# Persist the weights only when the result is acceptable.
if avg_acc >= 0.75:
    torch.save(model.state_dict(), 'continual_model.pth')
    print("💾 Model saved!")
