#!/usr/bin/env python3
"""
Eden's Autonomous Builder & Tester v2
Enhanced with better error handling and debugging
"""
import json
import subprocess
import sys
import time
from datetime import datetime
from pathlib import Path

import requests
import torch
import torch.nn as nn

class EdenAutonomousBuilder:
    """Build a layer from Eden's latest proposal and quick-test it on MNIST.

    Every build/test cycle is appended to a JSON log
    (``/Eden/DATA/build_test_results.json``); only the most recent
    ``MAX_LOGGED_RESULTS`` entries are kept, both in memory and on disk.
    """

    # Number of most recent results kept in memory and in the JSON log.
    MAX_LOGGED_RESULTS = 100

    def __init__(self):
        self.eden_api = "http://localhost:11434/api/generate"
        self.model = "qwen2.5:7b"

        self.results_log = Path("/Eden/DATA/build_test_results.json")
        self.results = self.load_results()

        # Baseline accuracies (percent) that new builds are compared against.
        self.best_mnist = 98.02
        self.best_cifar = 79.21

        print("🔨 Eden's Autonomous Builder & Tester v2")
        print(f"   Baseline: MNIST {self.best_mnist}%, CIFAR-10 {self.best_cifar}%")
        print(f"   Past builds: {len(self.results)}")
        print()

    def load_results(self):
        """Return the list of past results stored in ``self.results_log``.

        Returns an empty list when the log does not exist yet, or when it
        cannot be parsed as JSON (a warning is printed in that case so one
        corrupted write does not stop the builder from starting).
        """
        if not self.results_log.exists():
            return []
        try:
            with open(self.results_log, encoding='utf-8') as f:
                return json.load(f)
        except json.JSONDecodeError as e:
            print(f"   Warning: could not parse {self.results_log}: {e}")
            return []

    def save_result(self, result):
        """Append ``result`` and persist the most recent results to disk.

        The in-memory list is trimmed in place to the same window that is
        written, so memory and disk stay consistent.
        """
        self.results.append(result)
        del self.results[:-self.MAX_LOGGED_RESULTS]
        # The log directory may not exist yet on a fresh machine.
        self.results_log.parent.mkdir(parents=True, exist_ok=True)
        with open(self.results_log, 'w', encoding='utf-8') as f:
            json.dump(self.results, f, indent=2)

    def _next_cycle(self):
        """Return the number of the next build/test cycle.

        The log is truncated to the most recent entries, so ``len(results)``
        alone would stop increasing; prefer the last logged ``'cycle'`` value.
        """
        if not self.results:
            return 1
        last = self.results[-1]
        last_cycle = last.get('cycle') if isinstance(last, dict) else None
        if isinstance(last_cycle, int):
            return last_cycle + 1
        return len(self.results) + 1

    def ask_eden(self, prompt, timeout=180):
        """POST ``prompt`` to ``self.eden_api`` and return the generated text.

        Args:
            prompt: Text sent as the ``"prompt"`` field (non-streaming).
            timeout: Request timeout in seconds.

        Returns:
            The ``'response'`` field of the JSON reply ('' if absent), or
            None on network errors, non-2xx/3xx status, or an undecodable body.
        """
        try:
            response = requests.post(
                self.eden_api,
                json={"model": self.model, "prompt": prompt, "stream": False},
                timeout=timeout,
            )
            if response.ok:
                return response.json().get('response', '')
            print(f"   Error: HTTP {response.status_code}")
        except (requests.RequestException, ValueError) as e:
            # ValueError covers a response body that is not valid JSON.
            print(f"   Error: {e}")
        return None

    def read_latest_proposal(self):
        """Return ``(text, filename)`` of the latest proposal, or ``(None, None)``.

        "Latest" means the lexicographically last
        ``evolution_v4_proposal_*.txt`` in ``/Eden/DATA``.
        NOTE(review): assumes filenames sort chronologically — confirm.
        """
        proposals = sorted(Path("/Eden/DATA").glob("evolution_v4_proposal_*.txt"))
        if not proposals:
            return None, None
        latest = proposals[-1]
        with open(latest) as f:
            return f.read(), latest.name

    def generate_simple_implementation(self, proposal):
        """Return source code for a ``ProposedLayer`` module.

        The ``proposal`` text is currently ignored: this always returns a
        fixed v3.0-style layer (Linear -> BatchNorm1d -> learnable mix of
        ReLU and x*sigmoid(x)) as a baseline for Eden to improve upon.
        """
        code = '''import torch
import torch.nn as nn

class ProposedLayer(nn.Module):
    """
    Improved layer based on proposal
    Combines v3.0 hybrid activation with new ideas
    """
    
    def __init__(self, input_size, output_size):
        super(ProposedLayer, self).__init__()
        
        # Main transformation
        self.linear = nn.Linear(input_size, output_size)
        
        # Learnable mixing for hybrid activation
        self.alpha = nn.Parameter(torch.tensor(0.8))
        
        # Batch normalization
        self.bn = nn.BatchNorm1d(output_size)
    
    def forward(self, x):
        # Linear transformation
        z = self.linear(x)
        
        # Batch normalization
        z = self.bn(z)
        
        # Hybrid activation (v3.0 style)
        alpha_clamped = torch.sigmoid(self.alpha)
        relu_part = torch.relu(z)
        smooth_part = z * torch.sigmoid(z)  # Swish-like
        
        output = alpha_clamped * relu_part + (1 - alpha_clamped) * smooth_part
        
        return output
'''

        return code

    @staticmethod
    def _parse_test_output(stdout, stderr, returncode=0):
        """Extract ``(accuracy, error)`` from the test subprocess's output.

        The generated test script prints ``ACCURACY:<float>`` on success and
        ``ERROR:<message>`` when its own try/except catches a failure.

        Returns:
            ``(accuracy, None)`` on success, otherwise ``(None, message)``.
        """
        lines = (stdout + stderr).split('\n')
        for line in lines:
            if 'ACCURACY:' in line:
                value = line.split('ACCURACY:', 1)[1].strip()
                try:
                    return float(value), None
                except ValueError:
                    return None, f"Unparseable accuracy: {value!r}"
        for line in lines:
            if 'ERROR:' in line:
                return None, line
        if returncode != 0:
            # The script died before reaching its own try/except (e.g. a
            # syntax or import error in the generated layer code).
            tail = [line.strip() for line in stderr.split('\n') if line.strip()]
            if tail:
                return None, f"Exit code {returncode}: {tail[-1]}"
            return None, f"Exit code {returncode}"
        return None, "No accuracy output found"

    def test_on_mnist_quick(self, layer_code, test_name):
        """Train/test a small MNIST net built from ``layer_code`` in a subprocess.

        Writes a standalone script to ``/Eden/EXPERIMENTS/test_<test_name>.py``
        that trains for 2 epochs and evaluates on the MNIST test split
        (expects the dataset already in ``./data`` of the working directory;
        ``download=False``).

        Returns:
            ``(accuracy_percent, None)`` on success, ``(None, error)`` otherwise.
        """
        test_script = f'''#!/usr/bin/env python3
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import sys
import traceback

# Eden's proposed layer
{layer_code}

class TestNet(nn.Module):
    def __init__(self):
        super(TestNet, self).__init__()
        self.flatten = nn.Flatten()
        self.layer1 = ProposedLayer(784, 128)
        self.layer2 = ProposedLayer(128, 64)
        self.output = nn.Linear(64, 10)
    
    def forward(self, x):
        x = self.flatten(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.output(x)
        return x

def quick_test():
    try:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])
        
        train_dataset = datasets.MNIST('./data', train=True, download=False, transform=transform)
        test_dataset = datasets.MNIST('./data', train=False, transform=transform)
        
        train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
        test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)
        
        model = TestNet().to(device)
        optimizer = optim.Adam(model.parameters(), lr=0.001)
        criterion = nn.CrossEntropyLoss()
        
        # Train 2 epochs
        print("Training...", file=sys.stderr)
        for epoch in range(2):
            model.train()
            for batch_idx, (data, target) in enumerate(train_loader):
                data, target = data.to(device), target.to(device)
                optimizer.zero_grad()
                output = model(data)
                loss = criterion(output, target)
                loss.backward()
                optimizer.step()
                
                if batch_idx % 100 == 0:
                    print(f"Epoch {{epoch+1}}, Batch {{batch_idx}}", file=sys.stderr)
        
        # Test
        print("Testing...", file=sys.stderr)
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for data, target in test_loader:
                data, target = data.to(device), target.to(device)
                output = model(data)
                pred = output.argmax(dim=1)
                correct += pred.eq(target).sum().item()
                total += target.size(0)
        
        accuracy = 100. * correct / total
        print(f"ACCURACY:{{accuracy:.2f}}")
        return 0
        
    except Exception as e:
        print(f"ERROR:{{str(e)}}", file=sys.stderr)
        traceback.print_exc(file=sys.stderr)
        return 1

if __name__ == "__main__":
    sys.exit(quick_test())
'''

        test_file = f"/Eden/EXPERIMENTS/test_{test_name}.py"
        # The experiments directory may not exist yet on a fresh machine.
        Path(test_file).parent.mkdir(parents=True, exist_ok=True)
        with open(test_file, 'w', encoding='utf-8') as f:
            f.write(test_script)

        print(f"   Running quick test (2 epochs)...")

        try:
            # Use the interpreter running this builder (it already has torch);
            # a bare 'python3' from PATH may be a different environment.
            result = subprocess.run(
                [sys.executable or 'python3', test_file],
                capture_output=True,
                text=True,
                timeout=300,
            )
        except subprocess.TimeoutExpired:
            return None, "Timeout (>5min)"
        except OSError as e:
            return None, str(e)

        output = result.stdout + result.stderr

        # Debug: show output
        print(f"   [Debug] Last 5 lines of output:")
        for line in output.split('\n')[-5:]:
            if line.strip():
                print(f"      {line}")

        return self._parse_test_output(result.stdout, result.stderr, result.returncode)

    def _record_outcome(self, result, accuracy, error):
        """Fill ``result`` with the test outcome and update ``self.best_mnist``.

        ``accuracy`` is checked with ``is None`` so a genuine 0.00% run is
        recorded as a (poor) successful test rather than as a failure.

        Returns:
            The same ``result`` dict, mutated in place.
        """
        if accuracy is None:
            print(f"\n   ❌ Test failed: {error}")
            result['success'] = False
            result['error'] = error
            return result

        print(f"\n   ✅ Test complete: {accuracy:.2f}%")
        result['accuracy'] = accuracy
        result['success'] = True

        improvement = accuracy - self.best_mnist
        result['improvement'] = improvement

        if improvement > 0:
            print(f"   🎉 IMPROVEMENT: +{improvement:.2f}% over baseline!")
            self.best_mnist = accuracy
            result['new_best'] = True
        elif improvement > -2.0:
            print(f"   ✅ Competitive: {improvement:+.2f}% vs baseline")
        else:
            print(f"   ⚠️  Needs work: {improvement:+.2f}%")
        return result

    def build_and_test_cycle(self):
        """Run one cycle: read the latest proposal, build code, test, log.

        Returns:
            The logged result dict, or None when no proposal was found.
        """
        cycle = self._next_cycle()
        print(f"\n{'='*70}")
        print(f"BUILD & TEST CYCLE {cycle}")
        print(f"{'='*70}\n")

        proposal, proposal_name = self.read_latest_proposal()

        if not proposal:
            print("   No proposals found")
            return None

        print(f"📋 Proposal: {proposal_name}")
        print(f"   Preview: {proposal[:200]}...\n")

        # Generate simple implementation
        print(f"🔨 Building simple v3.0-style implementation...")
        code = self.generate_simple_implementation(proposal)

        proposal_stem = proposal_name.replace('.txt', '')
        code_file = f"/Eden/EXPERIMENTS/generated_{proposal_stem}.py"
        Path(code_file).parent.mkdir(parents=True, exist_ok=True)
        with open(code_file, 'w', encoding='utf-8') as f:
            f.write(code)
        print(f"   ✅ Code saved: {code_file}")

        # Test
        print(f"\n🧪 Testing on MNIST (quick validation)...")
        accuracy, error = self.test_on_mnist_quick(code, proposal_stem)

        result = {
            'timestamp': datetime.now().isoformat(),
            'cycle': cycle,
            'proposal': proposal_name,
            'code_file': code_file,
        }
        self._record_outcome(result, accuracy, error)

        self.save_result(result)

        print(f"\n{'='*70}\n")

        return result


if __name__ == "__main__":
    # Construct first: the constructor prints its own baseline summary
    # before the test banner below.
    eden_builder = EdenAutonomousBuilder()

    banner = "=" * 70
    for banner_line in (banner, "EDEN'S AUTONOMOUS BUILDER v2 - TEST", banner):
        print(banner_line)

    # Run a single build/test cycle.
    eden_builder.build_and_test_cycle()
