#!/usr/bin/env python3
"""
Eden's Autonomous Builder & Tester
Takes evolution proposals → builds code → tests → validates
True recursive self-improvement with empirical validation
"""
# TIE Integration (optional): the builder keeps working when the control
# module is missing or fails to import; TIE_ACTIVE records which case applies.
try:
    from eden_tie_control import should_act, get_mode, get_limits
    TIE_ACTIVE = True
except Exception:
    # Deliberately best-effort, but no longer a bare ``except:`` so that
    # KeyboardInterrupt / SystemExit raised during import are not swallowed.
    TIE_ACTIVE = False
import torch
import torch.nn as nn
import requests
import json
import subprocess
import time
from pathlib import Path
from datetime import datetime

class EdenAutonomousBuilder:
    """
    Eden builds and tests her own proposals
    - Reads evolution proposals
    - Generates implementation code
    - Tests on benchmarks
    - Validates improvements
    - Keeps best versions
    """
    
    def __init__(
        self,
        eden_api="http://localhost:11434/api/generate",
        model="phi3.5:latest",
        results_log="/Eden/DATA/build_test_results.json",
        best_mnist=98.02,
        best_cifar=79.21,
    ):
        """Configure the builder and load the history of past builds.

        Every argument defaults to the value that used to be hard-coded, so
        ``EdenAutonomousBuilder()`` behaves exactly as before.

        Args:
            eden_api: URL that ``ask_eden`` posts generation requests to.
            model: Model name sent in each generation request.
            results_log: Path (``str`` or ``Path``) of the JSON build history.
            best_mnist: Baseline MNIST accuracy (percent) to compare against.
            best_cifar: Baseline CIFAR-10 accuracy (percent); only reported.
        """
        self.eden_api = eden_api
        self.model = model

        # Path() accepts both str and Path, so callers may pass either.
        self.results_log = Path(results_log)
        self.results = self.load_results()

        # Current best performance (v3.0 baseline unless overridden)
        self.best_mnist = best_mnist
        self.best_cifar = best_cifar

        print("🔨 Eden's Autonomous Builder & Tester")
        print(f"   Baseline: MNIST {self.best_mnist}%, CIFAR-10 {self.best_cifar}%")
        print(f"   Past builds: {len(self.results)}")
        print()
    
    def load_results(self):
        """Load the build history from ``self.results_log``.

        Returns:
            The list stored in the log file. An empty list is returned when
            the file does not exist, cannot be read, does not contain valid
            JSON, or contains something other than a list; in the last three
            cases a warning is printed instead of crashing start-up.
        """
        try:
            # EAFP: open directly instead of exists()+open(), which can race.
            with open(self.results_log, encoding="utf-8") as f:
                data = json.load(f)
        except FileNotFoundError:
            return []
        except (OSError, ValueError) as e:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError,
            # e.g. a log truncated by an interrupted write.
            print(f"   Warning: could not read {self.results_log}: {e}")
            return []

        if not isinstance(data, list):
            # Callers append to and slice the history, so it must be a list.
            print(f"   Warning: {self.results_log} does not contain a list; ignoring it")
            return []
        return data
    
    def save_result(self, result):
        """Append ``result`` to the in-memory history and persist it.

        Only the 100 most recent entries are written to disk; the in-memory
        list is not trimmed.

        Args:
            result: JSON-serialisable record of one build-test cycle.
        """
        self.results.append(result)

        log_path = Path(self.results_log)
        # The log directory may not exist yet on a fresh machine.
        log_path.parent.mkdir(parents=True, exist_ok=True)

        # Write to a sibling temp file and rename it over the log so that an
        # interrupted or failed dump never leaves a half-written history.
        tmp_path = log_path.with_name(log_path.name + ".tmp")
        with open(tmp_path, 'w') as f:
            json.dump(self.results[-100:], f, indent=2)
        tmp_path.replace(log_path)
    
    def ask_eden(self, prompt, timeout=180):
        """Send a prompt to the Eden model endpoint and return its text reply.

        Args:
            prompt: Text sent as the ``prompt`` field of the JSON request.
            timeout: Seconds passed to ``requests.post`` as its timeout.

        Returns:
            The ``response`` field of the JSON reply (``''`` when the field is
            absent), or ``None`` when the request fails, the server answers
            with a non-OK status, or the body is not a JSON object. Every
            failure path prints a short diagnostic.
        """
        try:
            response = requests.post(
                self.eden_api,
                json={"model": self.model, "prompt": prompt, "stream": False},
                timeout=timeout
            )
            if not response.ok:
                # Previously this case returned None without any message.
                print(f"   Error: HTTP {response.status_code} from {self.eden_api}")
                return None
            payload = response.json()
        except (requests.RequestException, ValueError) as e:
            # RequestException: connection/timeout problems;
            # ValueError: the body was not valid JSON.
            print(f"   Error: {e}")
            return None

        if not isinstance(payload, dict):
            print(f"   Error: unexpected reply type {type(payload).__name__}")
            return None
        return payload.get('response', '')
    
    def read_latest_proposal(self, data_dir="/Eden/DATA"):
        """Read the most recent evolution proposal.

        "Most recent" means the file matching ``evolution_v4_proposal_*.txt``
        whose path sorts last, i.e. the same file the previous
        ``sorted(...)[-1]`` selected.

        Args:
            data_dir: Directory to search; defaults to the original
                hard-coded location.

        Returns:
            ``(text, file_name)`` for the selected proposal, or
            ``(None, None)`` when no proposal file is found.
        """
        candidates = Path(data_dir).glob("evolution_v4_proposal_*.txt")
        # max() picks the same element as sorted(...)[-1] without sorting.
        latest = max(candidates, default=None)
        if latest is None:
            return None, None

        with open(latest) as f:
            return f.read(), latest.name
    
    def generate_implementation(self, proposal):
        """Ask Eden to turn a proposal into PyTorch code.

        Args:
            proposal: Proposal text embedded verbatim in the prompt.

        Returns:
            The generated source with any markdown code fence removed and
            surrounding whitespace stripped, or ``None`` when Eden returned
            nothing. May be an empty string if the reply held no code.
        """
        import re  # local import: only this method needs it

        prompt = f"""Eden,

You proposed this improvement:

{proposal}

Now IMPLEMENT it. Generate COMPLETE, WORKING PyTorch code for testing.

Requirements:
1. Create a new layer class based on your proposal
2. Include forward() method
3. Make it compatible with MNIST testing
4. Keep it simple and testable
5. Include comments explaining the improvement

Generate ONLY the Python code, no explanations. Start with imports.

The code should define a class called 'ProposedLayer' that can replace a standard layer.
"""

        print(f"   Asking Eden to implement...")

        code = self.ask_eden(prompt, timeout=180)

        if not code:
            return None

        # Extract code if wrapped in markdown. A Python-tagged fence wins over
        # an untagged one. The tag is matched case-insensitively and as a whole
        # word, so "```py", "```Python" and "```python3" no longer leak the
        # tag (or its tail) into the extracted code.
        tagged = re.search(
            r"```[ \t]*(?:python3?|py)[ \t]*\r?\n(.*?)(?:```|\Z)",
            code,
            re.IGNORECASE | re.DOTALL,
        )
        if tagged:
            code = tagged.group(1)
        elif "```python" in code:
            # Fence with text on the same line as the tag: keep old behaviour.
            code = code.split("```python")[1].split("```")[0]
        elif "```" in code:
            code = code.split("```")[1]
            # A single word on the fence line is a language tag, not code.
            info, newline, rest = code.partition("\n")
            if newline and re.fullmatch(r"[\w+#.-]+", info.strip()):
                code = rest

        return code.strip()
    
    def test_on_mnist(self, layer_code, test_name):
        """Train and evaluate the proposed layer on MNIST in a subprocess.

        A stand-alone script is written to
        ``/Eden/EXPERIMENTS/test_<test_name>.py`` with ``layer_code`` pasted
        in, then run with the current interpreter (5-minute limit). The script
        trains a small network built from ``ProposedLayer`` for two epochs and
        prints ``ACCURACY:<percent>``, or ``ERROR:<message>`` on failure.

        NOTE(security): ``layer_code`` is model-generated and is executed
        without any sandboxing.

        Args:
            layer_code: Python source expected to define ``ProposedLayer``.
            test_name: Suffix used for the generated script's file name.

        Returns:
            ``(accuracy, None)`` on success or ``(None, error_message)`` on
            failure; the method reports problems through the tuple rather
            than raising.
        """
        import sys  # local import: needed for sys.executable

        test_script = f"""
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import sys

# Eden's proposed layer
{layer_code}

# Simple test network
class TestNet(nn.Module):
    def __init__(self):
        super(TestNet, self).__init__()
        self.flatten = nn.Flatten()
        # No fallback to plain nn.Linear here: if ProposedLayer cannot be
        # built the run must fail, otherwise a baseline network would be
        # scored and reported as if it were the proposal.
        self.layer1 = ProposedLayer(784, 128)
        self.layer2 = ProposedLayer(128, 64)
        self.output = nn.Linear(64, 10)
    
    def forward(self, x):
        x = self.flatten(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.output(x)
        return x

def quick_test():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # Load MNIST
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    
    train_dataset = datasets.MNIST('./data', train=True, download=False, transform=transform)
    test_dataset = datasets.MNIST('./data', train=False, transform=transform)
    
    train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)
    
    # Create and train model
    model = TestNet().to(device)
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()
    
    # Quick training (2 epochs for speed)
    for epoch in range(2):
        model.train()
        for data, target in train_loader:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
    
    # Test
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            pred = output.argmax(dim=1)
            correct += pred.eq(target).sum().item()
            total += target.size(0)
    
    accuracy = 100. * correct / total
    print(f"ACCURACY:{{accuracy:.2f}}")
    return accuracy

if __name__ == "__main__":
    try:
        quick_test()
    except Exception as e:
        print(f"ERROR:{{str(e)}}")
        sys.exit(1)
"""

        # Save test script. UTF-8 is Python's default source encoding, so
        # non-ASCII comments in the generated code stay readable by the child.
        test_file = f"/Eden/EXPERIMENTS/test_{test_name}.py"
        try:
            Path(test_file).parent.mkdir(parents=True, exist_ok=True)
            with open(test_file, 'w', encoding='utf-8') as f:
                f.write(test_script)
        except OSError as e:
            return None, f"Could not write test script: {e}"

        print(f"   Running test (2 epochs for speed)...")

        try:
            result = subprocess.run(
                # Same interpreter (and therefore the same installed torch)
                # as this process, rather than whatever 'python3' is on PATH.
                [sys.executable or 'python3', test_file],
                capture_output=True,
                text=True,
                timeout=300
            )
        except subprocess.TimeoutExpired:
            return None, "Timeout"
        except Exception as e:
            return None, str(e)

        # Parse accuracy from output
        lines = (result.stdout + result.stderr).splitlines()

        accuracy_lines = [line for line in lines if 'ACCURACY:' in line]
        if accuracy_lines:
            try:
                return float(accuracy_lines[0].split('ACCURACY:')[1]), None
            except ValueError as e:
                return None, str(e)

        error_lines = [line for line in lines if 'ERROR:' in line]
        if error_lines:
            return None, error_lines[0]

        # Neither marker was printed (e.g. the script died while importing
        # the generated code): surface the last output line as a hint.
        tail = next((line.strip() for line in reversed(lines) if line.strip()), "")
        return None, f"Unknown error: {tail}" if tail else "Unknown error"
    
    def build_and_test_cycle(self):
        """Run one complete build-test cycle on the latest proposal.

        Steps: read the latest proposal, ask Eden for an implementation, keep
        a copy of the generated code, test it on MNIST, compare the accuracy
        with ``self.best_mnist`` and record the outcome via ``save_result``.

        Returns:
            The result record (a dict) that was saved, or ``None`` when there
            is no proposal or no code could be generated.
        """

        print(f"\n{'='*70}")
        print(f"BUILD & TEST CYCLE {len(self.results) + 1}")
        print(f"{'='*70}\n")

        # Read latest proposal
        proposal, proposal_name = self.read_latest_proposal()

        if not proposal:
            print("   No proposals found")
            return None

        print(f"📋 Proposal: {proposal_name}")

        # Generate implementation
        print(f"\n🔨 Building...")
        code = self.generate_implementation(proposal)

        if not code:
            print("   Failed to generate code")
            return None

        build_name = proposal_name.replace('.txt', '')

        # Save generated code. This copy is only kept for inspection (the
        # test receives the code directly), so failing to write it is
        # reported but does not abort the cycle.
        code_file = f"/Eden/EXPERIMENTS/generated_{build_name}.py"
        try:
            Path(code_file).parent.mkdir(parents=True, exist_ok=True)
            with open(code_file, 'w', encoding='utf-8') as f:
                f.write(code)
            print(f"   Code saved: {code_file}")
        except OSError as e:
            print(f"   Warning: could not save generated code: {e}")
            code_file = None

        # Test on MNIST
        print(f"\n🧪 Testing on MNIST...")
        accuracy, error = self.test_on_mnist(code, build_name)

        result = {
            'timestamp': datetime.now().isoformat(),
            'cycle': len(self.results) + 1,
            'proposal': proposal_name,
            'code_file': code_file,
        }

        # Compare against None explicitly: 0.0 is a valid (if terrible)
        # accuracy and must not be reported as "Test failed: None".
        if accuracy is not None:
            print(f"   ✅ Test complete: {accuracy:.2f}%")
            result['accuracy'] = accuracy
            result['success'] = True

            # Compare to baseline
            improvement = accuracy - self.best_mnist
            result['improvement'] = improvement

            if improvement > 0:
                print(f"   🎉 IMPROVEMENT: +{improvement:.2f}% over baseline!")
                self.best_mnist = accuracy
                result['new_best'] = True
            elif improvement > -1.0:
                print(f"   ✅ Competitive: {improvement:+.2f}% vs baseline")
            else:
                print(f"   ⚠️  Below baseline: {improvement:+.2f}%")
        else:
            print(f"   ❌ Test failed: {error}")
            result['success'] = False
            result['error'] = error

        self.save_result(result)

        print(f"\n{'='*70}")
        print(f"{'='*70}\n")

        return result
    
    def continuous_building(self, cycles=10, interval_hours=2):
        """Run several build-test cycles in a row, then print a summary.

        Args:
            cycles: Number of times ``build_and_test_cycle`` is invoked.
            interval_hours: Pause between consecutive cycles, in hours; there
                is no pause after the final cycle.

        The summary is computed over everything in ``self.results``, not only
        over the cycles run by this call.
        """

        print(f"\n🔄 Starting continuous build-test loop")
        print(f"   Cycles: {cycles}")
        print(f"   Interval: {interval_hours} hours")
        print(f"   Baseline: MNIST {self.best_mnist}%\n")

        final_index = cycles - 1
        for index in range(cycles):
            self.build_and_test_cycle()

            if index == final_index:
                # Nothing follows the last cycle, so do not wait.
                continue

            print(f"💤 Next cycle in {interval_hours} hours...\n")
            time.sleep(interval_hours * 3600)

        # Summary
        print(f"\n{'='*70}")
        print(f"SUMMARY: {cycles} BUILD-TEST CYCLES COMPLETE")
        print(f"{'='*70}\n")

        history = self.results
        success_count = sum(1 for entry in history if entry.get('success'))
        # Entries without an 'improvement' key (failed builds) never count.
        gains = [
            entry['improvement']
            for entry in history
            if entry.get('improvement', -999) > 0
        ]

        print(f"Total cycles: {len(history)}")
        print(f"Successful builds: {success_count}")
        print(f"Improvements found: {len(gains)}")
        print(f"Best MNIST: {self.best_mnist:.2f}%")

        if gains:
            print(f"\n🎉 Best improvement: +{max(gains):.2f}%")


if __name__ == "__main__":
    import sys

    # Command-line entry point:
    #   <script>                  -> run a single build-test cycle
    #   <script> continuous [N]   -> run N cycles (default 10), 2 hours apart
    builder = EdenAutonomousBuilder()
    cli_args = sys.argv[1:]

    if cli_args and cli_args[0] == "continuous":
        # The optional second argument is the cycle count.
        cycle_args = cli_args[1:2]
        cycles = int(cycle_args[0]) if cycle_args else 10
        builder.continuous_building(cycles=cycles, interval_hours=2)
    else:
        # Single test cycle
        print("="*70)
        print("EDEN'S AUTONOMOUS BUILDER - TEST")
        print("="*70)

        builder.build_and_test_cycle()

        print("\n" + "="*70)
        print("To run continuous building:")
        print("   python3 eden_autonomous_builder.py continuous 10")
