#!/usr/bin/env python3
"""
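CONTINUAL LEARNING WITH REPLAY
Store a small random subset of examples from each task and replay them
alongside the new task's data when learning later tasks, to reduce forgetting.
Replay is often more effective than EWC (elastic weight consolidation) alone;
this script does not run an EWC baseline for comparison.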
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

torch.manual_seed(42)
np.random.seed(42)

# Fall back to CPU so the script still runs on machines without a GPU;
# hard-coding 'cuda' made every `.to(device)` below fail there.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}\n")

class ContinualNet(nn.Module):
    """Small MLP classifier shared by every task in the sequence.

    Architecture: Linear -> ReLU -> Dropout -> Linear -> ReLU -> Dropout ->
    Linear. The defaults reproduce the original 20 -> 128 -> 128 -> 5 network.

    Args:
        in_features: size of each input vector.
        hidden: width of both hidden layers.
        num_classes: number of output logits.
        dropout: dropout probability applied after each hidden layer.
    """

    def __init__(self, in_features=20, hidden=128, num_classes=5, dropout=0.2):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_features, hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, num_classes),
        )

    def forward(self, x):
        """Return unnormalized class logits (one per class) for input *x*."""
        return self.net(x)

def create_task(task_id, n_samples=200, *, target_device=None):
    """Create one of five synthetic, learnable classification tasks.

    Every sample is a 20-dim standard-normal vector; each task labels it with
    a different rule and a different pair of classes out of 0..4:

    * task 0: 0 if sum(x[:10]) > 0 else 1
    * task 1: 2 if mean(x[5:15]) > 0 else 3
    * task 2: 4 if std(x[-10:]) > 1.0 else 0
    * task 3: 1 if sum(x[::2]) > 0 else 2
    * task 4: 3 if max(|x|) > 1.5 else 4

    Samples come from NumPy's global RNG, so output depends on (and advances)
    the state set by ``np.random.seed``.

    Args:
        task_id: which labeling rule to use, 0..4.
        n_samples: number of examples to draw.
        target_device: device for the returned tensors; defaults to the
            module-level ``device``.

    Returns:
        (X, Y): float32 tensor of shape (n_samples, 20) and int64 tensor of
        shape (n_samples,).

    Raises:
        ValueError: if task_id is not one of 0..4.
    """
    rules = {
        0: lambda x: 0 if x[:10].sum() > 0 else 1,
        1: lambda x: 2 if x[5:15].mean() > 0 else 3,
        2: lambda x: 4 if x[-10:].std() > 1.0 else 0,
        3: lambda x: 1 if x[::2].sum() > 0 else 2,
        4: lambda x: 3 if np.abs(x).max() > 1.5 else 4,
    }
    # Validate before drawing so an invalid id does not consume RNG state
    # (previously any unknown id silently used the task-4 rule).
    if task_id not in rules:
        raise ValueError(f"task_id must be one of {sorted(rules)}, got {task_id!r}")
    label_of = rules[task_id]

    # One draw per sample, exactly as before, so seeded runs are reproducible.
    rows = [np.random.randn(20) for _ in range(n_samples)]
    labels = [label_of(x) for x in rows]  # labels computed in float64, as before

    # Stack into a single ndarray first: torch.FloatTensor(list_of_ndarrays)
    # is very slow and warns. reshape keeps shape (0, 20) when n_samples == 0.
    X = np.asarray(rows, dtype=np.float32).reshape(n_samples, 20)

    if target_device is None:
        target_device = device
    X_t = torch.as_tensor(X).to(target_device)
    Y_t = torch.as_tensor(labels, dtype=torch.long).to(target_device)
    return X_t, Y_t

print("=" * 70, "CONTINUAL LEARNING WITH REPLAY", "=" * 70, sep="\n")

# One network and one optimizer are shared across the whole task sequence.
model = ContinualNet().to(device)
opt = torch.optim.Adam(model.parameters(), lr=0.001)

# Replay memory: a list of per-task input tensors ('X'), a matching list of
# label tensors ('Y'), and a flat list of the source task id of each example.
replay_buffer = dict(X=[], Y=[], task_id=[])
replay_size_per_task = 30  # how many examples to keep from every task

# After each task: list of accuracies on all tasks learned so far.
task_accuracies = []

def _accuracy(net, X, Y):
    """Return the classification accuracy of *net* on (X, Y) as a float.

    Switches the network to eval mode so Dropout is disabled while measuring,
    then restores whichever mode (train/eval) it was in before.
    """
    was_training = net.training
    net.eval()
    try:
        with torch.no_grad():
            return (net(X).argmax(1) == Y).float().mean().item()
    finally:
        net.train(was_training)


# Learn 5 tasks sequentially
for task_id in range(5):
    print(f"\n{'='*70}")
    print(f"TASK {task_id}")
    print(f"{'='*70}")

    # Get new task data
    X_new, Y_new = create_task(task_id, n_samples=200)

    # Store a random subset of this task in the replay buffer. It is only
    # replayed while learning *later* tasks (the [:-1] slices below).
    indices = torch.randperm(len(X_new))[:replay_size_per_task]
    replay_buffer['X'].append(X_new[indices])
    replay_buffer['Y'].append(Y_new[indices])
    # Record the number actually stored; it can be smaller than
    # replay_size_per_task when the task has fewer samples.
    replay_buffer['task_id'].extend([task_id] * len(indices))

    # Training set: every new-task sample plus the stored samples of all
    # previous tasks (the entry just appended is excluded).
    if len(replay_buffer['X']) > 1:
        X_train = torch.cat([X_new, *replay_buffer['X'][:-1]])
        Y_train = torch.cat([Y_new, *replay_buffer['Y'][:-1]])
    else:
        X_train, Y_train = X_new, Y_new

    # Train with Dropout enabled; _accuracy restores this mode after measuring.
    model.train()
    for epoch in range(100):
        # Shuffle training data
        perm = torch.randperm(len(X_train))
        X_shuffled = X_train[perm]
        Y_shuffled = Y_train[perm]

        # Mini-batch training
        for i in range(0, len(X_train), 32):
            batch_X = X_shuffled[i:i+32]
            batch_Y = Y_shuffled[i:i+32]

            loss = F.cross_entropy(model(batch_X), batch_Y)

            opt.zero_grad()
            loss.backward()
            opt.step()

        if epoch % 25 == 0:
            # Measured on the new task's training samples, with Dropout off.
            acc = _accuracy(model, X_new, Y_new)
            print(f"  Epoch {epoch}: Current task acc={acc*100:.1f}%")

    # Test on ALL tasks seen so far, using freshly drawn samples.
    print("\nTesting all tasks:")
    task_accs = []
    for test_id in range(task_id + 1):
        X_test, Y_test = create_task(test_id, n_samples=200)
        acc = _accuracy(model, X_test, Y_test)
        task_accs.append(acc)
        status = "✅" if acc >= 0.75 else "⚠️" if acc >= 0.60 else "❌"
        print(f"  {status} Task {test_id}: {acc*100:.1f}%")

    task_accuracies.append(task_accs)

# Final evaluation
print("\n" + "="*70)
print("FINAL RESULTS")
print("="*70)

# One row per training stage: accuracy on every task learned up to that stage.
print("\nAccuracy after learning each task:")
for i, accs in enumerate(task_accuracies):
    row = " ".join(f"T{j}={acc*100:.0f}%" for j, acc in enumerate(accs))
    print(f"After Task {i}: {row}")

# Summary over the accuracies measured after the last task was learned.
print("\n" + "="*70)
final_accs = task_accuracies[-1]
avg_acc = np.mean(final_accs)
min_acc = np.min(final_accs)

# Derive the task count instead of hard-coding it, so the label stays correct
# if the number of tasks in the training loop changes.
print(f"Average accuracy on all {len(final_accs)} tasks: {avg_acc*100:.1f}%")
print(f"Minimum accuracy: {min_acc*100:.1f}%")

if avg_acc >= 0.80 and min_acc >= 0.65:
    print("✅ EXCELLENT - Continual learning works!")
elif avg_acc >= 0.70:
    print("✅ GOOD - Reasonable continual learning!")
elif avg_acc >= 0.60:
    print("⚠️ MODERATE - Some forgetting")
else:
    print("❌ POOR - Too much forgetting")
