"""
MAML Implementation - Model-Agnostic Meta-Learning
Learns to learn: finds an initialization that enables fast adaptation
"""
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
from typing import List
import copy

class MAMLModel(nn.Module):
    """Simple neural network for MAML"""
    
    def __init__(self, input_size: int, hidden_sizes: List[int], output_size: int):
        super().__init__()
        
        layers = []
        prev_size = input_size
        
        for hidden_size in hidden_sizes:
            layers.append(nn.Linear(prev_size, hidden_size))
            layers.append(nn.ReLU())
            prev_size = hidden_size
        
        layers.append(nn.Linear(prev_size, output_size))
        self.network = nn.Sequential(*layers)
    
    def forward(self, x):
        return self.network(x)
    
    def get_flat_params(self):
        """Get flattened parameters"""
        return torch.cat([p.view(-1) for p in self.parameters()])
    
    def set_flat_params(self, flat_params):
        """Set parameters from flattened vector"""
        offset = 0
        with torch.no_grad():
            for p in self.parameters():
                numel = p.numel()
                p.copy_(flat_params[offset:offset + numel].view_as(p))
                offset += numel
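
# Illustrative usage of the flat-parameter helpers (not part of the MAML
# algorithm itself; handy for checkpointing or parameter-space interpolation):
#
#   model = MAMLModel(4, [8], 2)
#   flat = model.get_flat_params()
#   model.set_flat_params(flat)  # round-trips, model is unchanged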

class MAML:
    """
    Model-Agnostic Meta-Learning
    Learns an initialization that enables fast adaptation to new tasks
    """
    
    def __init__(
        self,
        model: nn.Module,
        inner_lr: float = 0.01,
        outer_lr: float = 0.001,
        inner_steps: int = 5,
        first_order: bool = False
    ):
        self.model = model
        self.inner_lr = inner_lr
        self.outer_lr = outer_lr
        self.inner_steps = inner_steps
        self.first_order = first_order  # First-order approximation (FOMAML): cheaper, drops second-order terms
        
        self.meta_optimizer = optim.Adam(self.model.parameters(), lr=outer_lr)
        
    def inner_loop(self, support_x, support_y):
        """
        Inner loop: Adapt to a specific task
        Returns adapted parameters
        """
        # Clone the parameters (not the model) for task-specific adaptation;
        # the clones keep their autograd connection to the originals
        adapted_params = [p.clone() for p in self.model.parameters()]
        
        for step in range(self.inner_steps):
            # Forward pass with current adapted params
            logits = self._forward_with_params(support_x, adapted_params)
            loss = F.cross_entropy(logits, support_y)
            
            # Compute gradients
            grads = torch.autograd.grad(
                loss,
                adapted_params,
                create_graph=not self.first_order,  # Second-order if not first_order
                allow_unused=True
            )
            
            # Update adapted parameters
            adapted_params = [
                p - self.inner_lr * g if g is not None else p
                for p, g in zip(adapted_params, grads)
            ]
        
        return adapted_params
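
    # Gradient-flow note: each inner update is
    #   theta'_{k+1} = theta'_k - inner_lr * grad L_support(theta'_k)
    # With create_graph=True the subtraction stays in the autograd graph, so
    # the outer loss differentiates through all inner steps (second-order
    # MAML). With first_order=True the inner gradients are detached and only
    # the identity path through `p` carries gradient back to the
    # initialization (FOMAML).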
    
    def _forward_with_params(self, x, params):
        """Forward pass using specific parameters"""
        # Manually apply the Sequential's modules with the given parameter
        # list; parameters() yields (weight, bias) per Linear, in order
        param_idx = 0
        
        for module in self.model.network:
            if isinstance(module, nn.Linear):
                weight = params[param_idx]
                bias = params[param_idx + 1]
                x = F.linear(x, weight, bias)
                param_idx += 2
            elif isinstance(module, nn.ReLU):
                x = F.relu(x)
        
        return x
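
    # A functional forward is needed because self.model(x) always uses the
    # registered leaf nn.Parameters and cannot see the differentiable adapted
    # tensors from inner_loop. This walk assumes the Sequential holds only
    # Linear and ReLU modules, as built by MAMLModel above.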
    
    def meta_train_step(self, task_batch):
        """
        One meta-training step on a batch of tasks
        
        task_batch: List of (support_x, support_y, query_x, query_y)
        """
        self.meta_optimizer.zero_grad()
        
        total_loss = 0
        total_acc = 0
        
        for support_x, support_y, query_x, query_y in task_batch:
            # Inner loop: adapt to task
            adapted_params = self.inner_loop(support_x, support_y)
            
            # Evaluate adapted model on query set
            query_logits = self._forward_with_params(query_x, adapted_params)
            query_loss = F.cross_entropy(query_logits, query_y)
            
            # Accumulate loss for meta-update
            total_loss += query_loss
            
            # Compute accuracy
            pred = query_logits.argmax(dim=1)
            acc = (pred == query_y).float().mean()
            total_acc += acc
        
        # Meta-update (outer loop)
        avg_loss = total_loss / len(task_batch)
        avg_acc = total_acc / len(task_batch)
        
        avg_loss.backward()
        self.meta_optimizer.step()
        
        return avg_loss.item(), avg_acc.item()
    
    def adapt_to_task(self, support_x, support_y):
        """
        Adapt model to a new task (for evaluation)
        Returns adapted model
        """
        adapted_model = copy.deepcopy(self.model)
        optimizer = optim.SGD(adapted_model.parameters(), lr=self.inner_lr)
        
        for step in range(self.inner_steps):
            optimizer.zero_grad()
            logits = adapted_model(support_x)
            loss = F.cross_entropy(logits, support_y)
            loss.backward()
            optimizer.step()
        
        return adapted_model
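
    # Note: adapt_to_task mirrors inner_loop but runs plain SGD on a deep
    # copy, without building a meta-gradient graph; it is intended for
    # evaluation, not for meta-training.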
    
    def evaluate_task(self, support_x, support_y, query_x, query_y):
        """
        Evaluate on a single task
        Returns accuracy before and after adaptation
        """
        # Before adaptation
        self.model.eval()
        with torch.no_grad():
            logits_before = self.model(query_x)
            pred_before = logits_before.argmax(dim=1)
            acc_before = (pred_before == query_y).float().mean().item()
        
        # Adapt to task
        adapted_model = self.adapt_to_task(support_x, support_y)
        
        # After adaptation
        adapted_model.eval()
        with torch.no_grad():
            logits_after = adapted_model(query_x)
            pred_after = logits_after.argmax(dim=1)
            acc_after = (pred_after == query_y).float().mean().item()
        
        return acc_before, acc_after
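

# A minimal meta-training loop sketch, assuming the synthetic tasks from
# create_classification_task below. Illustrative only: real few-shot
# benchmarks (Omniglot, miniImageNet) need proper task samplers, and the
# task-id range, iteration count, and batch size here are arbitrary.
def meta_train_demo(maml: MAML, n_iterations: int = 100, tasks_per_batch: int = 4):
    """Run a few meta-training steps on randomly sampled synthetic tasks."""
    for it in range(n_iterations):
        # Random task ids -> random decision boundaries (see below)
        task_batch = [
            create_classification_task(int(np.random.randint(100_000)))
            for _ in range(tasks_per_batch)
        ]
        loss, acc = maml.meta_train_step(task_batch)
        if (it + 1) % 20 == 0:
            print(f"Iter {it + 1}: meta-loss {loss:.4f}, query acc {acc * 100:.1f}%")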


def create_sinusoid_task(amplitude=None, phase=None, n_samples=10):
    """
    Create a sinusoid regression task
    Classic MAML benchmark: y = A * sin(x + φ)
    """
    if amplitude is None:
        amplitude = np.random.uniform(0.1, 5.0)
    if phase is None:
        phase = np.random.uniform(0, np.pi)
    
    x = np.random.uniform(-5, 5, (n_samples, 1))
    y = amplitude * np.sin(x + phase)
    
    return torch.FloatTensor(x), torch.FloatTensor(y)
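
# Note: the sinusoid task is regression, while the MAML class above
# hardcodes F.cross_entropy; using it would mean swapping in an MSE loss
# in inner_loop and meta_train_step, e.g. (illustrative):
#
#   x, y = create_sinusoid_task(n_samples=10)
#   loss = F.mse_loss(model(x), y)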


def create_classification_task(task_id, n_samples=20, n_features=10):
    """
    Create a binary classification task
    Different tasks have different decision boundaries
    """
    # Task-local RNG so per-task seeding doesn't perturb global numpy state
    rng = np.random.RandomState(task_id)
    
    # Generate random decision boundary
    w = rng.randn(n_features)
    w = w / np.linalg.norm(w)
    b = rng.randn()
    
    # Generate data
    X = rng.randn(n_samples, n_features)
    y = (X @ w + b > 0).astype(int)
    
    # Split into support and query
    support_x = torch.FloatTensor(X[:n_samples//2])
    support_y = torch.LongTensor(y[:n_samples//2])
    query_x = torch.FloatTensor(X[n_samples//2:])
    query_y = torch.LongTensor(y[n_samples//2:])
    
    return support_x, support_y, query_x, query_y


if __name__ == "__main__":
    print("\n" + "="*70)
    print("MAML: QUICK TEST")
    print("="*70)
    
    # Create model
    model = MAMLModel(input_size=10, hidden_sizes=[40, 40], output_size=2)
    maml = MAML(model, inner_lr=0.01, outer_lr=0.001, inner_steps=5)
    
    print(f"\nModel Parameters: {sum(p.numel() for p in model.parameters())}")
    print(f"Inner LR: {maml.inner_lr}")
    print(f"Outer LR: {maml.outer_lr}")
    print(f"Inner Steps: {maml.inner_steps}")
    
    # Test on one task
    print("\n" + "="*70)
    print("Testing inner loop adaptation...")
    print("="*70)
    
    support_x, support_y, query_x, query_y = create_classification_task(0, n_samples=20)
    
    acc_before, acc_after = maml.evaluate_task(support_x, support_y, query_x, query_y)
    
    print(f"\nAccuracy before adaptation: {acc_before*100:.1f}%")
    print(f"Accuracy after {maml.inner_steps} steps: {acc_after*100:.1f}%")
    print(f"Improvement: {(acc_after - acc_before)*100:.1f}%")
    
    print("\n✅ MAML implementation ready")
    print("   Next: Meta-training loop")
