"""
Transfer Learning - Reuse learned features for new tasks
Core idea: Features learned on Task A are useful for Task B
"""
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.utils.data import DataLoader, TensorDataset
from typing import List, Optional

class FeatureExtractor(nn.Module):
    """
    Pre-trainable feature extractor
    Learns general features that transfer across tasks
    """
    
    def __init__(self, input_size: int, feature_sizes: List[int]):
        super().__init__()
        
        layers = []
        prev_size = input_size
        
        for feature_size in feature_sizes:
            layers.append(nn.Linear(prev_size, feature_size))
            layers.append(nn.ReLU())
            layers.append(nn.BatchNorm1d(feature_size))
            prev_size = feature_size
        
        self.features = nn.Sequential(*layers)
        self.feature_dim = prev_size
    
    def forward(self, x):
        return self.features(x)

class TransferLearner(nn.Module):
    """
    Transfer learning model with frozen/fine-tunable feature extractor
    """
    
    def __init__(
        self,
        input_size: int,
        feature_sizes: Optional[List[int]] = None,
        n_classes: int = 10
    ):
        super().__init__()
        
        # Default assigned here to avoid a mutable default argument
        if feature_sizes is None:
            feature_sizes = [128, 64]
        self.feature_extractor = FeatureExtractor(input_size, feature_sizes)
        self.classifier = nn.Linear(self.feature_extractor.feature_dim, n_classes)
        
    def forward(self, x):
        features = self.feature_extractor(x)
        return self.classifier(features)
    
    def freeze_features(self):
        """Freeze feature extractor for transfer learning.
        eval() also stops BatchNorm running stats from updating,
        which requires_grad = False alone does not prevent.
        """
        self.feature_extractor.eval()
        for param in self.feature_extractor.parameters():
            param.requires_grad = False
    
    def unfreeze_features(self):
        """Unfreeze for fine-tuning"""
        self.feature_extractor.train()
        for param in self.feature_extractor.parameters():
            param.requires_grad = True
    
    def replace_classifier(self, n_classes: int):
        """Replace classifier head for new task"""
        self.classifier = nn.Linear(self.feature_extractor.feature_dim, n_classes)
    
    def extract_features(self, x):
        """Get feature representations (eval mode avoids mutating BatchNorm stats)"""
        was_training = self.feature_extractor.training
        self.feature_extractor.eval()
        with torch.no_grad():
            features = self.feature_extractor(x)
        self.feature_extractor.train(was_training)
        return features

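# Illustrative helper (an addition, not part of the original API): counting
# trainable parameters is a quick way to confirm that freeze_features() /
# unfreeze_features() behaved as expected before starting a training run.
def count_trainable_params(model: nn.Module) -> int:
    """Number of parameters that will receive gradients"""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
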
def pretrain_on_source_tasks(model, n_tasks=20, samples_per_task=1000, epochs=50):
    """
    Pre-train feature extractor on multiple source tasks
    This creates transferable features
    """
    print("\n" + "="*70)
    print("PRE-TRAINING FEATURE EXTRACTOR ON SOURCE TASKS")
    print(f"Tasks: {n_tasks} | Samples per task: {samples_per_task} | Epochs: {epochs}")
    print("="*70)
    
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    model.train()  # make sure BatchNorm layers are in training mode
    
    for epoch in range(epochs):
        total_loss = 0
        total_correct = 0
        total_samples = 0
        
        # Train on multiple tasks
        for task_id in range(n_tasks):
            # Generate this task's data (deterministic per task_id)
            X, y = generate_task_data(task_id, samples_per_task, input_size=20)
            
            # Forward pass
            optimizer.zero_grad()
            outputs = model(X)
            loss = nn.functional.cross_entropy(outputs, y)
            
            # Backward pass
            loss.backward()
            optimizer.step()
            
            # Metrics
            total_loss += loss.item()
            pred = outputs.argmax(dim=1)
            total_correct += (pred == y).sum().item()
            total_samples += len(y)
        
        if (epoch + 1) % 10 == 0:
            avg_loss = total_loss / n_tasks
            accuracy = 100. * total_correct / total_samples
            print(f"Epoch {epoch+1}/{epochs} - Loss: {avg_loss:.4f} - Acc: {accuracy:.2f}%")
    
    print("\n✅ Pre-training complete!")
    print("   Feature extractor now has transferable representations")

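# Illustrative sketch (an addition, not in the original flow): a "linear probe"
# fits only a fresh linear head on frozen features, a common way to gauge how
# transferable a pre-trained representation is before committing to fine-tuning.
def linear_probe_accuracy(model, X, y, n_classes, epochs=30, lr=0.01):
    """Train a linear classifier on frozen features; return accuracy on (X, y)"""
    feats = model.extract_features(X)  # no gradients flow into the extractor
    probe = nn.Linear(feats.shape[1], n_classes)
    opt = optim.Adam(probe.parameters(), lr=lr)
    for _ in range(epochs):
        opt.zero_grad()
        loss = nn.functional.cross_entropy(probe(feats), y)
        loss.backward()
        opt.step()
    with torch.no_grad():
        return (probe(feats).argmax(dim=1) == y).float().mean().item()
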
def transfer_to_target_task(
    model,
    target_X,
    target_y,
    n_classes_target,
    freeze_features=True,
    fine_tune_epochs=20
):
    """
    Transfer to new target task with few examples
    """
    print("\n" + "="*70)
    print(f"TRANSFER LEARNING TO TARGET TASK")
    print(f"Target classes: {n_classes_target} | Freeze features: {freeze_features}")
    print("="*70)
    
    # Replace classifier for new task
    model.replace_classifier(n_classes_target)
    
    if freeze_features:
        model.freeze_features()
        print("   Features frozen - training classifier only")
    else:
        print("   Fine-tuning all parameters")
    
    # Train on target task
    optimizer = optim.Adam(
        [p for p in model.parameters() if p.requires_grad],
        lr=0.01
    )
    
    dataset = TensorDataset(target_X, target_y)
    loader = DataLoader(dataset, batch_size=32, shuffle=True)
    
    for epoch in range(fine_tune_epochs):
        total_loss = 0
        correct = 0
        total = 0
        
        for inputs, targets in loader:
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = nn.functional.cross_entropy(outputs, targets)
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
            pred = outputs.argmax(dim=1)
            correct += (pred == targets).sum().item()
            total += len(targets)
        
        if (epoch + 1) % 5 == 0:
            accuracy = 100. * correct / total
            print(f"Epoch {epoch+1}/{fine_tune_epochs} - Acc: {accuracy:.2f}%")
    
    print("\n✅ Transfer learning complete!")

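# Hypothetical baseline sketch (an assumption, not part of the original script):
# training an identical but un-pretrained architecture on the target data alone
# quantifies the benefit of transfer, since any accuracy gap is attributable to
# the pre-trained features.
def train_from_scratch_baseline(target_X, target_y, n_classes, input_size=20, epochs=20):
    """Train a fresh model on the target task only, for comparison"""
    baseline = TransferLearner(input_size=input_size, n_classes=n_classes)
    transfer_to_target_task(
        baseline, target_X, target_y, n_classes,
        freeze_features=False, fine_tune_epochs=epochs
    )
    return baseline
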
def generate_task_data(task_id, n_samples, input_size=20, n_classes=10):
    """Generate synthetic task data, deterministic per task_id"""
    rng = np.random.RandomState(task_id * 100)  # local RNG avoids reseeding global state
    
    # A random linear map defines this task's labels
    W = rng.randn(input_size, n_classes)
    
    # Inputs drawn from a standard normal distribution
    X = rng.randn(n_samples, input_size)
    logits = X @ W
    y = logits.argmax(axis=1)
    
    return torch.FloatTensor(X), torch.LongTensor(y)

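# Note (an observation about the generator above): it reseeds per task_id, so
# repeated calls with the same task_id reproduce the same samples. The demo
# below therefore draws target and test data in one call and splits them.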
if __name__ == "__main__":
    print("\n" + "="*70)
    print("TRANSFER LEARNING TEST")
    print("="*70)
    
    # Create model
    model = TransferLearner(input_size=20, feature_sizes=[128, 64], n_classes=10)
    print(f"\nModel created - Feature dim: {model.feature_extractor.feature_dim}")
    
    # Pre-train on source tasks
    pretrain_on_source_tasks(model, n_tasks=20, samples_per_task=500, epochs=30)
    
    # Transfer to new target task (different classes)
    print("\n" + "="*70)
    print("TESTING TRANSFER TO NEW TASK")
    print("="*70)
    
    # Draw target data once, then hold out a test split; a second call with the
    # same task_id would regenerate overlapping samples
    all_X, all_y = generate_task_data(999, n_samples=150, n_classes=5)
    target_X, target_y = all_X[:100], all_y[:100]
    test_X, test_y = all_X[100:], all_y[100:]
    transfer_to_target_task(model, target_X, target_y, n_classes_target=5, freeze_features=True, fine_tune_epochs=20)
    
    # Evaluate on the held-out test split
    model.eval()
    with torch.no_grad():
        outputs = model(test_X)
        pred = outputs.argmax(dim=1)
        accuracy = (pred == test_y).float().mean().item() * 100
    
    print(f"\nTest Accuracy: {accuracy:.2f}%")
    print("\n✅ Transfer learning implementation ready")
