"""
Few-Shot Learning with Prototypical Networks
Learn new classes from just N examples per class (N-shot learning)
Paper: "Prototypical Networks for Few-shot Learning" (2017)
"""
from typing import List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from transfer_learner import FeatureExtractor

class PrototypicalNetwork(nn.Module):
    """
    Prototypical Networks for Few-Shot Learning.

    Learns to map examples to an embedding space where classification
    is performed by finding the nearest class prototype, i.e. the mean
    embedding of that class's support examples.
    """

    def __init__(self, input_size: int, feature_sizes: Optional[List[int]] = None):
        """
        Args:
            input_size: dimensionality of each raw input vector.
            feature_sizes: layer sizes forwarded to ``FeatureExtractor``.
                Defaults to ``[128, 64]``. A fresh default list is built per
                call so instances never share (or mutate) one default object.
        """
        super().__init__()
        if feature_sizes is None:
            feature_sizes = [128, 64]
        self.feature_extractor = FeatureExtractor(input_size, feature_sizes)

    def forward(self, x):
        """Return the embeddings produced by the feature extractor for ``x``."""
        return self.feature_extractor(x)

    def compute_prototypes(self, support_embeddings, support_labels, n_way):
        """
        Compute the prototype (mean embedding) of each class.

        Args:
            support_embeddings: [n_support, embedding_dim] support embeddings.
            support_labels: [n_support] integer labels; classes ``0..n_way-1``
                are averaged, any other label value is ignored.
            n_way: number of classes in the episode.

        Returns:
            Tensor [n_way, embedding_dim]; row ``c`` is the mean of the
            embeddings labelled ``c``.

        Raises:
            ValueError: if a class in ``range(n_way)`` has no support examples.
                Its mean would be NaN and would silently turn every distance,
                log-probability and the loss into NaN.
        """
        prototypes = []
        for c in range(n_way):
            class_embeddings = support_embeddings[support_labels == c]
            if len(class_embeddings) == 0:
                raise ValueError(f"class {c} has no support examples")
            prototypes.append(class_embeddings.mean(dim=0))

        return torch.stack(prototypes)  # [n_way, embedding_dim]

    def classify_queries(self, query_embeddings, prototypes):
        """
        Score queries against class prototypes.

        The logit for (query i, class c) is the negative squared Euclidean
        distance between query i and prototype c, so the nearest prototype
        receives the highest score.

        Args:
            query_embeddings: [n_queries, embedding_dim]
            prototypes: [n_way, embedding_dim]

        Returns:
            [n_queries, n_way] log-probabilities (log-softmax over classes).
        """
        # Squared distances computed directly rather than cdist(...).pow(2).
        # This avoids a sqrt followed by a square. It also avoids the precision
        # loss of cdist's matmul-based formula, which cdist uses once there are
        # more than 25 rows. The broadcast temporary is
        # [n_queries, n_way, embedding_dim], which is small for few-shot sizes.
        diffs = query_embeddings.unsqueeze(1) - prototypes.unsqueeze(0)
        sq_distances = diffs.pow(2).sum(dim=-1)  # [n_queries, n_way]

        return F.log_softmax(-sq_distances, dim=1)

def create_episode(task_id, n_way=5, n_shot=5, n_query=15, input_size=20):
    """
    Create a synthetic few-shot episode.

    Each class gets a random mean drawn as ``2 * N(0, I)``. Its support and
    query examples are that mean plus unit-variance Gaussian noise. Both sets
    are shuffled before being returned.

    Args:
        task_id: seed for the episode. The same ``task_id`` (with the same
            other arguments) always yields the same episode.
        n_way: number of classes.
        n_shot: examples per class in the support set.
        n_query: examples per class in the query set.
        input_size: dimensionality of each example.

    Returns:
        ``(support_X, support_y, query_X, query_y)``:
        ``support_X`` is float32 [n_way * n_shot, input_size];
        ``support_y`` is int64 [n_way * n_shot];
        ``query_X`` is float32 [n_way * n_query, input_size];
        ``query_y`` is int64 [n_way * n_query].
        Labels are in ``range(n_way)``.
    """
    # A private legacy generator seeded with task_id draws exactly the same
    # numbers as np.random.seed(task_id) followed by np.random.* calls.
    # Unlike that approach, it does not reset the process-wide NumPy RNG
    # behind every other caller's back.
    rng = np.random.RandomState(task_id)

    support_chunks, query_chunks = [], []
    support_labels, query_labels = [], []

    # Draw order (mean, support, query per class) is kept fixed so a given
    # task_id keeps producing the same episode.
    for class_id in range(n_way):
        class_mean = rng.randn(input_size) * 2

        support_chunks.append(rng.randn(n_shot, input_size) + class_mean)
        support_labels.extend([class_id] * n_shot)

        query_chunks.append(rng.randn(n_query, input_size) + class_mean)
        query_labels.extend([class_id] * n_query)

    support_X = np.vstack(support_chunks)
    query_X = np.vstack(query_chunks)

    # Shuffle so examples are not grouped by class.
    support_idx = rng.permutation(len(support_labels))
    query_idx = rng.permutation(len(query_labels))

    return (
        torch.as_tensor(support_X[support_idx], dtype=torch.float32),
        torch.as_tensor(np.asarray(support_labels)[support_idx], dtype=torch.long),
        torch.as_tensor(query_X[query_idx], dtype=torch.float32),
        torch.as_tensor(np.asarray(query_labels)[query_idx], dtype=torch.long),
    )

def train_prototypical_network(
    model,
    n_episodes=2000,
    n_way=5,
    n_shot=5,
    n_query=15,
    lr=0.001,
    input_size=20,
    log_every=200,
):
    """
    Train a prototypical network episodically with Adam.

    Episode ``i`` is generated by ``create_episode`` with ``task_id=i``.
    Each episode builds class prototypes from the support set, classifies
    the query set, and takes one optimizer step on the NLL loss.

    Args:
        model: network exposing ``__call__`` (embeddings),
            ``compute_prototypes`` and ``classify_queries``.
        n_episodes: number of training episodes (task ids ``0..n_episodes-1``).
        n_way: classes per episode.
        n_shot: support examples per class.
        n_query: query examples per class.
        lr: Adam learning rate.
        input_size: feature dimension of generated examples. Must match the
            model's expected input size.
        log_every: print loss/accuracy every this many episodes
            (``0`` or ``None`` disables progress lines).
    """
    print("\n" + "="*70)
    print("TRAINING PROTOTYPICAL NETWORK (FEW-SHOT LEARNING)")
    print(f"Episodes: {n_episodes} | {n_way}-way {n_shot}-shot")
    print("="*70)

    # Make sure we train in training mode. A previous model.eval() (e.g. from
    # an evaluation pass) would otherwise leave any mode-dependent layers
    # (dropout, batch norm, if present) in inference behavior.
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    for episode in range(n_episodes):
        support_X, support_y, query_X, query_y = create_episode(
            episode, n_way, n_shot, n_query, input_size
        )

        support_embeddings = model(support_X)
        query_embeddings = model(query_X)

        prototypes = model.compute_prototypes(support_embeddings, support_y, n_way)
        log_probs = model.classify_queries(query_embeddings, prototypes)

        # classify_queries returns log-probabilities, so NLL == cross-entropy.
        loss = F.nll_loss(log_probs, query_y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Accuracy of this episode's (pre-step) predictions; for logging only.
        pred = log_probs.argmax(dim=1)
        accuracy = (pred == query_y).float().mean().item()

        if log_every and (episode + 1) % log_every == 0:
            print(f"Episode {episode+1}/{n_episodes} - "
                  f"Loss: {loss.item():.4f} - Acc: {accuracy*100:.2f}%")

    print("\n✅ Training complete!")

def evaluate_few_shot(
    model,
    n_test_episodes=100,
    n_way=5,
    n_shot=5,
    n_query=15,
    input_size=20,
    seed_offset=10000,
):
    """
    Evaluate a prototypical network on freshly generated episodes.

    Test episodes use ``task_id = seed_offset + i``. Keep ``seed_offset`` at
    or above the number of training episodes so test tasks do not overlap
    the training tasks.

    The model is switched to eval mode and run without gradients. Its
    previous train/eval mode is restored afterwards, even if an error occurs.

    Args:
        model: network exposing ``__call__`` (embeddings),
            ``compute_prototypes`` and ``classify_queries``.
        n_test_episodes: number of evaluation episodes; must be positive.
        n_way: classes per episode.
        n_shot: support examples per class.
        n_query: query examples per class.
        input_size: feature dimension of generated examples.
        seed_offset: first task id used for evaluation episodes.

    Returns:
        Mean query accuracy over all episodes, in percent.

    Raises:
        ValueError: if ``n_test_episodes`` is not positive. The average would
            otherwise divide by zero.
    """
    if n_test_episodes <= 0:
        raise ValueError(f"n_test_episodes must be positive, got {n_test_episodes}")

    print("\n" + "="*70)
    print(f"EVALUATING ON {n_test_episodes} NEW EPISODES")
    print(f"{n_way}-way {n_shot}-shot classification")
    print("="*70)

    was_training = model.training
    model.eval()
    total_accuracy = 0.0

    try:
        with torch.no_grad():
            for episode in range(n_test_episodes):
                # Offset task ids so evaluation episodes differ from training ones.
                support_X, support_y, query_X, query_y = create_episode(
                    seed_offset + episode, n_way, n_shot, n_query, input_size
                )

                support_embeddings = model(support_X)
                query_embeddings = model(query_X)

                prototypes = model.compute_prototypes(
                    support_embeddings, support_y, n_way
                )
                log_probs = model.classify_queries(query_embeddings, prototypes)

                pred = log_probs.argmax(dim=1)
                total_accuracy += (pred == query_y).float().mean().item()
    finally:
        # Don't leave the caller's model stuck in eval mode.
        model.train(was_training)

    avg_accuracy = (total_accuracy / n_test_episodes) * 100

    print(f"\n{'='*70}")
    print(f"Average Accuracy: {avg_accuracy:.2f}%")
    print(f"{'='*70}")

    if avg_accuracy >= 80:
        print(f"\n🎯 GOAL ACHIEVED! >80% on {n_way}-way {n_shot}-shot")
    else:
        print(f"\n📊 Current: {avg_accuracy:.2f}% (goal: 80%)")

    return avg_accuracy

if __name__ == "__main__":
    # Demo run: train on synthetic 5-way 5-shot episodes over 20-dimensional
    # inputs, then report accuracy on held-out episodes.
    rule = "=" * 70
    print("\n" + rule)
    print("FEW-SHOT LEARNING WITH PROTOTYPICAL NETWORKS")
    print(rule)

    # Episode shape shared by training and evaluation.
    episode_shape = {"n_way": 5, "n_shot": 5, "n_query": 15}

    model = PrototypicalNetwork(input_size=20, feature_sizes=[128, 64])
    embedding_dim = model.feature_extractor.feature_dim
    print(f"\nModel created - Embedding dim: {embedding_dim}")

    train_prototypical_network(model, n_episodes=2000, **episode_shape)

    accuracy = evaluate_few_shot(model, n_test_episodes=100, **episode_shape)

    print("\n✅ Few-shot learning capability ready!")
    print(f"   5-way 5-shot accuracy: {accuracy:.1f}%")
