"""
Neural Architecture Search (NAS)
Automatically design neural network architectures
"""
import torch
import torch.nn as nn
import numpy as np
from typing import List, Tuple
import time

class SearchSpace:
    """NAS search space: the per-layer choices an architecture is drawn from.

    Each attribute is a plain list of candidate values for one key of a layer
    specification; ``sample_architecture`` draws from these lists.
    """

    def __init__(self):
        # Candidate values for the 'type', 'activation' and 'size' keys.
        self.layer_types = ['linear', 'conv']
        self.activations = ['relu', 'tanh']
        self.layer_sizes = [32, 64, 128]

    @staticmethod
    def _pick(options):
        """Draw one element of *options* with NumPy's global RNG, as a native value."""
        picked = np.random.choice(options)
        # np.random.choice returns NumPy scalars (np.int64 / np.str_) for
        # numeric and string lists. Unwrap them so architecture dicts hold
        # plain int/str values that serialize and compare like ordinary
        # Python objects.
        return picked.item() if isinstance(picked, np.generic) else picked

    def sample_architecture(self, n_layers: int = 3) -> List[dict]:
        """Sample a random architecture.

        Args:
            n_layers: Number of layer specifications to draw. Zero (or a
                negative value) yields an empty list.

        Returns:
            A list of ``n_layers`` dicts, each with the keys ``'type'``,
            ``'size'`` and ``'activation'`` drawn independently from
            ``layer_types``, ``layer_sizes`` and ``activations``.
        """
        # Key order matters: the dict literal is evaluated left to right, so
        # the RNG is consumed as type -> size -> activation for every layer.
        return [
            {
                'type': self._pick(self.layer_types),
                'size': self._pick(self.layer_sizes),
                'activation': self._pick(self.activations),
            }
            for _ in range(n_layers)
        ]

def build_model(architecture: List[dict], input_size: int, output_size: int) -> nn.Module:
    """Assemble an ``nn.Sequential`` network from a list of layer specs.

    Every spec contributes one ``nn.Linear`` of width ``spec['size']``
    followed by the activation named in ``spec['activation']`` (``'relu'`` or
    ``'tanh'``; any other name adds no activation). A final ``nn.Linear``
    maps the last hidden width to ``output_size``.

    NOTE: the ``'type'`` key of a spec is not consulted here; all hidden
    layers are linear regardless of its value.
    """
    activation_factories = {'relu': nn.ReLU, 'tanh': nn.Tanh}

    modules = []
    in_features = input_size
    for spec in architecture:
        out_features = spec['size']
        modules.append(nn.Linear(in_features, out_features))

        # Unknown activation names fall through without adding a module.
        make_activation = activation_factories.get(spec['activation'])
        if make_activation is not None:
            modules.append(make_activation())

        # The next layer consumes what this one produced.
        in_features = out_features

    # Projection from the last hidden width to the requested output width.
    modules.append(nn.Linear(in_features, output_size))
    return nn.Sequential(*modules)

def evaluate_architecture(
    architecture: List[dict],
    train_data: Tuple[torch.Tensor, torch.Tensor],
    epochs: int = 10,
    num_classes: int = 2,
    lr: float = 0.01,
) -> float:
    """Train a model built from *architecture* and score it.

    The model is trained full-batch with Adam for ``epochs`` steps using
    cross-entropy loss, then scored on the same tensors it was trained on.
    The returned value is therefore a *training* accuracy, not a held-out
    estimate.

    Args:
        architecture: Layer specifications understood by ``build_model``.
        train_data: ``(X, y)`` pair. ``X.shape[1]`` is used as the model's
            input width; ``y`` is passed to ``cross_entropy`` as the target
            (presumably integer class indices below ``num_classes`` —
            confirm against the caller).
        epochs: Number of full-batch optimisation steps.
        num_classes: Width of the model's output layer. Defaults to 2, the
            value that was previously hard-coded.
        lr: Adam learning rate. Defaults to the previously hard-coded 0.01.

    Returns:
        Fraction of rows in ``X`` whose arg-max prediction equals ``y``.
    """
    X_train, y_train = train_data

    model = build_model(
        architecture,
        input_size=X_train.shape[1],
        output_size=num_classes,
    )
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # Quick training: one gradient step on the whole dataset per epoch.
    for _ in range(epochs):
        optimizer.zero_grad()
        outputs = model(X_train)
        loss = nn.functional.cross_entropy(outputs, y_train)
        loss.backward()
        optimizer.step()

    # Evaluate on the training tensors without tracking gradients.
    with torch.no_grad():
        pred = model(X_train).argmax(dim=1)
        accuracy = (pred == y_train).float().mean().item()

    return accuracy

def random_search_nas(
    train_data: Tuple[torch.Tensor, torch.Tensor],
    n_trials: int = 20
) -> Tuple[List[dict], float]:
    """Random search over architectures.

    Samples ``n_trials`` three-layer architectures from a fresh
    ``SearchSpace``, scores each with ``evaluate_architecture`` (10 epochs)
    and keeps the best one. Progress is printed every fifth trial and a
    summary of the winner is printed at the end.

    Args:
        train_data: ``(X, y)`` tensors forwarded to ``evaluate_architecture``.
        n_trials: Number of architectures to sample; must be at least 1.

    Returns:
        ``(best_architecture, best_accuracy)``.

    Raises:
        ValueError: If ``n_trials`` is smaller than 1, since no architecture
            could be returned.
    """
    if n_trials < 1:
        raise ValueError(f"n_trials must be at least 1, got {n_trials}")

    print("\n" + "="*70)
    print("NEURAL ARCHITECTURE SEARCH (Random Search)")
    print(f"Trials: {n_trials}")
    print("="*70)

    search_space = SearchSpace()

    best_arch = None
    best_acc = 0.0

    for trial in range(n_trials):
        # Sample architecture
        arch = search_space.sample_architecture(n_layers=3)

        # Evaluate
        acc = evaluate_architecture(arch, train_data, epochs=10)

        # The first trial always becomes the incumbent, so a run in which
        # every architecture scores exactly 0.0 still returns an architecture
        # instead of None.
        if best_arch is None or acc > best_acc:
            best_acc = acc
            best_arch = arch

        if (trial + 1) % 5 == 0:
            print(f"Trial {trial+1}/{n_trials} - Best acc so far: {best_acc*100:.2f}%")

    print(f"\n{'='*70}")
    print("BEST ARCHITECTURE FOUND")
    print(f"{'='*70}")
    print(f"Accuracy: {best_acc*100:.2f}%")
    print("\nArchitecture:")
    for i, layer in enumerate(best_arch):
        print(f"  Layer {i+1}: {layer['size']} units, {layer['activation']}")

    return best_arch, best_acc

def evolutionary_nas(
    train_data: Tuple[torch.Tensor, torch.Tensor],
    population_size: int = 10,
    generations: int = 5
) -> Tuple[List[dict], float]:
    """Evolutionary search over architectures.

    Starts from ``population_size`` random three-layer architectures. In each
    generation every member is scored with ``evaluate_architecture``
    (10 epochs), the top half survives, and the population is refilled with
    mutated copies of randomly chosen survivors. Survivors stay in the
    population, so they are scored again in the next generation.

    Args:
        train_data: ``(X, y)`` tensors forwarded to ``evaluate_architecture``.
        population_size: Number of architectures per generation; must be at
            least 1.
        generations: Number of evaluate/select/mutate rounds. With 0 rounds
            nothing is evaluated and ``(None, 0.0)`` is returned.

    Returns:
        ``(best_architecture, best_accuracy)`` over all evaluations performed.

    Raises:
        ValueError: If ``population_size`` is smaller than 1.
    """
    if population_size < 1:
        raise ValueError(
            f"population_size must be at least 1, got {population_size}"
        )

    print("\n" + "="*70)
    print("NEURAL ARCHITECTURE SEARCH (Evolutionary)")
    print(f"Population: {population_size} | Generations: {generations}")
    print("="*70)

    search_space = SearchSpace()

    # Initialize population
    population = [search_space.sample_architecture(n_layers=3) for _ in range(population_size)]

    best_overall = None
    best_acc_overall = 0.0

    # Always keep at least one survivor: with population_size == 1 the plain
    # "top 50%" slice would be empty and there would be no parent to mutate.
    n_survivors = max(1, population_size // 2)

    for gen in range(generations):
        # Evaluate population
        fitness = []
        for arch in population:
            acc = evaluate_architecture(arch, train_data, epochs=10)
            fitness.append(acc)

            # The first evaluated architecture always becomes the incumbent,
            # so an all-zero run still reports an architecture.
            if best_overall is None or acc > best_acc_overall:
                best_acc_overall = acc
                best_overall = arch

        print(f"Generation {gen+1}/{generations} - Best: {max(fitness)*100:.2f}% - Avg: {np.mean(fitness)*100:.2f}%")

        # Selection: keep top 50% (sorted() is stable, so ties keep their
        # population order)
        sorted_pop = sorted(zip(population, fitness), key=lambda x: x[1], reverse=True)
        survivors = [arch for arch, _ in sorted_pop[:n_survivors]]

        # Reproduction: mutate survivors to create new population
        population = survivors.copy()
        while len(population) < population_size:
            parent = survivors[np.random.randint(len(survivors))]
            child = mutate_architecture(parent, search_space)
            population.append(child)

    print(f"\n{'='*70}")
    print("BEST ARCHITECTURE (EVOLUTIONARY)")
    print(f"{'='*70}")
    print(f"Accuracy: {best_acc_overall*100:.2f}%")

    return best_overall, best_acc_overall

def mutate_architecture(arch: List[dict], search_space: SearchSpace) -> List[dict]:
    """Return a mutated copy of *arch*; the input is left untouched.

    One layer is picked at random and either its ``'size'`` or its
    ``'activation'`` is redrawn from ``search_space``. The redraw may select
    the value the layer already had, in which case the copy equals the
    parent. The ``'type'`` key is never mutated.

    Args:
        arch: Parent architecture (list of layer-spec dicts).
        search_space: Provides ``layer_sizes`` and ``activations`` to draw
            replacement values from.

    Returns:
        A new list of new dicts. An empty *arch* yields an empty list.
    """
    new_arch = [layer.copy() for layer in arch]
    if not new_arch:
        # Nothing to mutate; np.random.randint(0) would raise on an empty range.
        return new_arch

    # Randomly mutate one layer
    idx = np.random.randint(len(new_arch))
    mutation_type = np.random.choice(['size', 'activation'])

    if mutation_type == 'size':
        key, options = 'size', search_space.layer_sizes
    else:
        key, options = 'activation', search_space.activations

    picked = np.random.choice(options)
    # np.random.choice returns NumPy scalars (np.int64 / np.str_); store the
    # native Python value so mutated specs hold plain int/str entries.
    new_arch[idx][key] = picked.item() if isinstance(picked, np.generic) else picked

    return new_arch

if __name__ == "__main__":
    # Demo: compare random and evolutionary search on a synthetic binary task.
    print("\n" + "="*70)
    print("NEURAL ARCHITECTURE SEARCH")
    print("="*70)

    # Seed both RNGs. NumPy drives architecture sampling and mutation, while
    # torch generates the data below and initialises model weights; seeding
    # only NumPy leaves the run non-reproducible.
    np.random.seed(42)
    torch.manual_seed(42)

    # Generate dummy data: 500 samples, 20 features, label = sign of row sum
    X = torch.randn(500, 20)
    y = (X.sum(dim=1) > 0).long()
    train_data = (X, y)

    # Random search
    arch_random, acc_random = random_search_nas(train_data, n_trials=20)

    # Evolutionary search
    arch_evo, acc_evo = evolutionary_nas(train_data, population_size=10, generations=5)

    print("\n" + "="*70)
    print("NAS RESULTS")
    print("="*70)
    print(f"Random Search:      {acc_random*100:.2f}% accuracy")
    print(f"Evolutionary:       {acc_evo*100:.2f}% accuracy")

    # Ties are reported as "Random Search".
    best_method = "Evolutionary" if acc_evo > acc_random else "Random Search"
    print(f"\nBest method: {best_method}")

    print("\n✅ Neural Architecture Search implementation complete!")
    print("   Used in: AutoML, efficient deep learning")
