#!/usr/bin/env python3
"""
Continual Learning V3 - Stronger anti-forgetting
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import defaultdict
import copy

class ContinualModelV3(nn.Module):
    """Three-layer MLP classifier: two ReLU hidden layers and a linear head.

    Args:
        input_dim: Size of the last dimension of the input.
        hidden_dim: Width of both hidden layers.
        output_dim: Number of output logits.
    """

    def __init__(self, input_dim=10, hidden_dim=128, output_dim=2):
        super().__init__()
        # Attribute names and creation order fix the state_dict keys and the
        # order in which parameter initialization consumes the torch RNG.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Return unnormalized logits for input ``x``."""
        hidden = x
        for layer in (self.fc1, self.fc2):
            hidden = F.relu(layer(hidden))
        return self.fc3(hidden)

class ContinualLearnerV3:
    """Train a ContinualModelV3 on a sequence of tasks using rehearsal (replay).

    Every example passed to ``learn_task`` is kept in ``self.task_data``.
    While later tasks are trained, samples from all earlier tasks are mixed
    back into the loss so the model keeps being optimized on them.
    """

    def __init__(self):
        # Prefer the GPU when one is visible to torch.
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print("Initializing continual learner V3...")
        self.model = ContinualModelV3().to(self.device)
        # task name -> list of (input, label) tensor pairs, one per example.
        self.task_data = defaultdict(list)
        print("✅ Continual learning V3 ready!")

    def learn_task(self, data, labels, task_name, epochs=100, lr=0.0005,
                   new_task_weight=0.2, replay_weight=0.8, max_grad_norm=0.5):
        """Train on a new task while replaying every previously stored task.

        Args:
            data: Input batch for the task. It is indexed along dim 0 and fed
                to the model as one full batch every epoch.
            labels: Class-index targets, one per row of ``data``.
            task_name: Key under which this task's examples are stored.
                Reusing a name replaces that task's stored examples and
                excludes it from replay during this call.
            epochs: Number of full-batch optimizer steps.
            lr: Adam learning rate. A fresh optimizer is created per call.
            new_task_weight: Weight of the current-task loss when replay
                data exists.
            replay_weight: Weight of the mean replay loss.
            max_grad_norm: Gradient-norm clipping threshold.

        Raises:
            ValueError: If ``data`` is empty or its length differs from
                ``labels``.
        """
        if len(data) != len(labels):
            raise ValueError(
                f"data and labels must have the same length, "
                f"got {len(data)} and {len(labels)}")
        if len(data) == 0:
            # Cross-entropy over an empty batch is NaN. Stepping on it would
            # corrupt every model parameter.
            raise ValueError("cannot learn a task from an empty batch")

        optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)

        print(f"\nLearning task: {task_name}")

        # Store ALL data for this task (detached copies, safe to keep around).
        self.task_data[task_name] = [(data[i].detach().clone(), labels[i].detach().clone())
                                     for i in range(len(data))]

        # Stack each previous task's examples once, instead of on every epoch.
        replay_sets = [
            (torch.stack([x for x, _ in examples]), torch.stack([y for _, y in examples]))
            for prev_task, examples in self.task_data.items()
            if prev_task != task_name and examples
        ]

        for _ in range(epochs):
            new_loss = F.cross_entropy(self.model(data), labels)

            # Replay a random subset (at most len(data)) of every previous task.
            replay_losses = []
            for replay_x, replay_y in replay_sets:
                sample_size = min(len(data), len(replay_x))
                # Use the torch RNG (not `random`) so that torch.manual_seed
                # alone makes a training run reproducible.
                idx = torch.randperm(len(replay_x))[:sample_size]
                replay_pred = self.model(replay_x[idx.to(replay_x.device)])
                replay_losses.append(
                    F.cross_entropy(replay_pred, replay_y[idx.to(replay_y.device)]))

            # Strong replay: by default 20% new task, 80% averaged old tasks.
            if replay_losses:
                total_replay = sum(replay_losses) / len(replay_losses)
                loss = new_task_weight * new_loss + replay_weight * total_replay
            else:
                loss = new_loss

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_grad_norm)
            optimizer.step()

        print(f"✅ Learned {task_name}")

    def evaluate(self, data, labels):
        """Return the fraction of rows whose argmax prediction equals ``labels``.

        Runs without gradient tracking. The result is a Python float. An empty
        batch yields NaN, because it takes the mean of an empty tensor.
        """
        with torch.no_grad():
            pred = self.model(data)
            acc = (pred.argmax(dim=1) == labels).float().mean().item()
        return acc

def test_continual_v3(seed=None):
    """Train two conflicting tasks in sequence and report Task 1 forgetting.

    Task 1 maps random inputs to class 0. Task 2 maps fresh random inputs to
    class 1. After Task 2 is learned, both tasks are evaluated again.

    Args:
        seed: If given, seeds both ``torch`` and ``random`` before the learner
            and data are created, so a run can be reproduced. ``None``
            (default) keeps the unseeded behaviour.

    Returns:
        True when both accuracies after Task 2 are at least 75%, else False.
    """
    if seed is not None:
        # Seed both RNGs. Replay sampling inside the learner may draw from
        # either one, depending on its implementation.
        import random
        random.seed(seed)
        torch.manual_seed(seed)

    print("\n" + "="*70)
    print("TESTING CONTINUAL LEARNING V3")
    print("="*70)

    cl = ContinualLearnerV3()

    # Task 1: every example labelled class 0.
    task1_data = torch.randn(50, 10).to(cl.device)
    task1_labels = torch.zeros(50, dtype=torch.long).to(cl.device)

    cl.learn_task(task1_data, task1_labels, "Task 1")
    acc1_initial = cl.evaluate(task1_data, task1_labels)
    print(f"Task 1 initial: {acc1_initial*100:.1f}%")

    # Task 2: every example labelled class 1 (conflicts with Task 1).
    task2_data = torch.randn(50, 10).to(cl.device)
    task2_labels = torch.ones(50, dtype=torch.long).to(cl.device)

    cl.learn_task(task2_data, task2_labels, "Task 2")

    # Evaluate both tasks after the second one has been learned.
    acc1_after = cl.evaluate(task1_data, task1_labels)
    acc2 = cl.evaluate(task2_data, task2_labels)

    print(f"\nAfter learning Task 2:")
    print(f"Task 1 accuracy: {acc1_after*100:.1f}%")
    print(f"Task 2 accuracy: {acc2*100:.1f}%")
    # Drop in Task 1 accuracy, in percentage points.
    print(f"Forgetting: {(acc1_initial - acc1_after)*100:.1f}%")

    if acc1_after >= 0.85 and acc2 >= 0.85:
        print("\n✅ EXCELLENT - Minimal forgetting!")
        return True
    elif acc1_after >= 0.75 and acc2 >= 0.75:
        print("\n✅ GOOD - Reduced forgetting!")
        return True
    else:
        print("\n⚠️ Still some forgetting")
        return False

def main():
    """Run the V3 continual-learning check and announce when it passes."""
    passed = test_continual_v3()
    if not passed:
        return
    print("\n✅ CAPABILITY #15 COMPLETE (IMPROVED)")


if __name__ == "__main__":
    main()
