#!/usr/bin/env python3
"""
EDEN-SOAR INTEGRATION
=====================
Self-improving evolutionary program synthesis using Ollama.
Based on: https://github.com/flowersteam/SOAR

Core loop:
1. SAMPLE: Generate candidate programs and score them on the examples
2. REFINE: Improve promising candidates
3. LEARN: Record successes AND failures (hindsight relabeling) as data for fine-tuning
"""
import sys
sys.path.insert(0, '/Eden/SOAR')
sys.path.insert(0, '/Eden/CORE')

import ast
import copy
import json
import sqlite3
import traceback
from contextlib import closing
from dataclasses import dataclass
from datetime import datetime
from typing import List, Dict, Any, Optional

from openai import OpenAI

# Shared client for the local Ollama server's OpenAI-compatible endpoint.
# Ollama does not need a real key, so a placeholder value is supplied.
OLLAMA_CLIENT = OpenAI(api_key="ollama", base_url="http://localhost:11434/v1")

# Model tag EdenSOAR uses when the caller does not pass one explicitly.
DEFAULT_MODEL = "qwen2.5:14b"

@dataclass
class SynthesisAttempt:
    """Record of a program synthesis attempt.

    NOTE(review): EdenSOAR in this module writes attempts straight to SQLite
    and never constructs this dataclass; presumably kept for external
    callers — confirm before removing.
    """
    task_id: str  # task identifier
    program: str  # Python source of the candidate program
    success: bool  # whether the attempt succeeded
    output: Any  # value the program produced
    expected: Any  # value the task expected
    timestamp: str  # presumably ISO-8601, like the timestamps EdenSOAR stores — confirm
    
@dataclass
class RefinementResult:
    """Result of refining a program.

    NOTE(review): not constructed anywhere in this module —
    EdenSOAR.refine_program returns the refined source as a plain str.
    """
    original: str  # program source before refinement
    refined: str  # program source after refinement
    improvement: bool  # whether the refinement counted as an improvement
    feedback: str  # feedback text the refinement was based on

class EdenSOAR:
    """
    Self-improving program synthesis for Eden.
    Implements the SOAR virtuous cycle:
    Sample → Execute → Refine → Learn

    Candidates are generated through the OpenAI-compatible ``client``,
    executed in-process, scored against input/output examples, and every
    attempt is logged to the SQLite database at ``db_path``.
    ``get_training_data`` exports successes plus hindsight relabels.

    SECURITY: generated programs are run with ``exec`` in this interpreter,
    with full privileges and no timeout. This is NOT a sandbox; an infinite
    loop in a generated program hangs the caller.
    """

    # Filename used when compiling generated programs; lets
    # _find_entry_point tell the program's own functions from imported ones.
    _PROGRAM_FILENAME = "<soar-program>"

    def __init__(self, model: str = DEFAULT_MODEL, db_path: str = "/Eden/DATA/soar_evolution.db"):
        """
        Args:
            model: Model name passed to the chat-completions API.
            db_path: SQLite file for attempts, refinements and hindsight
                relabels; the tables are created if missing.
        """
        self.model = model
        self.client = OLLAMA_CLIENT
        self.db_path = db_path
        self._init_db()
        print(f"[EDEN-SOAR] Initialized with {model}")

    # ── persistence helpers ─────────────────────────────────────────

    def _connect(self):
        """Return a context manager yielding a connection that is always closed.

        A bare ``sqlite3.Connection`` used in ``with`` only commits or rolls
        back; it never closes. ``closing`` releases the handle even when a
        statement raises.
        """
        return closing(sqlite3.connect(self.db_path))

    @staticmethod
    def _to_json(value: Any) -> str:
        """Serialize *value* for a diagnostic TEXT column without raising.

        Values JSON cannot represent (sets, numpy arrays, custom objects) are
        stored as their ``repr`` so logging can never abort an evolution run.
        """
        try:
            return json.dumps(value, default=repr)
        except ValueError:
            # e.g. "Circular reference detected"
            return json.dumps(repr(value))

    def _init_db(self):
        """Create the synthesis_attempts, refinements and hindsight_relabels tables if missing."""
        with self._connect() as conn:
            with conn:  # commit on success, roll back on error
                conn.execute('''CREATE TABLE IF NOT EXISTS synthesis_attempts (
                    id INTEGER PRIMARY KEY,
                    task_id TEXT,
                    task_description TEXT,
                    program TEXT,
                    success INTEGER,
                    output TEXT,
                    expected TEXT,
                    error TEXT,
                    timestamp TEXT
                )''')
                conn.execute('''CREATE TABLE IF NOT EXISTS refinements (
                    id INTEGER PRIMARY KEY,
                    original_id INTEGER,
                    refined_program TEXT,
                    improvement INTEGER,
                    feedback TEXT,
                    timestamp TEXT,
                    FOREIGN KEY (original_id) REFERENCES synthesis_attempts(id)
                )''')
                conn.execute('''CREATE TABLE IF NOT EXISTS hindsight_relabels (
                    id INTEGER PRIMARY KEY,
                    original_task TEXT,
                    failed_program TEXT,
                    synthetic_task TEXT,
                    synthetic_input TEXT,
                    synthetic_output TEXT,
                    timestamp TEXT
                )''')

    # ── SAMPLE ──────────────────────────────────────────────────────

    def sample_program(self, task: str, examples: Optional[List[Dict]] = None, n_samples: int = 1) -> List[str]:
        """
        SAMPLE PHASE: Generate candidate programs for a task.

        Args:
            task: Task description.
            examples: Optional dicts with ``'input'``/``'output'`` keys shown
                to the model.
            n_samples: Number of completions to request.

        Returns:
            The extracted programs. Failed requests and responses without
            extractable code are logged/skipped, so the list may be shorter
            than *n_samples*.
        """
        prompt = self._build_sample_prompt(task, examples)

        programs = []
        for _ in range(n_samples):
            try:
                response = self.client.chat.completions.create(
                    model=self.model,
                    messages=[
                        {"role": "system", "content": "You are an expert Python programmer. Generate a function that solves the given task. Output ONLY the Python code, no explanations."},
                        {"role": "user", "content": prompt}
                    ],
                    temperature=0.8,  # Higher for diversity
                    max_tokens=2000
                )
                code = self._extract_code(response.choices[0].message.content)
                if code:
                    programs.append(code)
            except Exception as e:
                # Best effort: a failed request costs one sample, not the run.
                print(f"[SAMPLE] Error: {e}")

        return programs

    # ── EXECUTE ─────────────────────────────────────────────────────

    def execute_program(self, program: str, test_input: Any) -> tuple[bool, Any, Optional[str]]:
        """
        Execute a program and call its entry function on *test_input*.

        The entry function is chosen by ``_find_entry_point``. The input is
        deep-copied first, so a program that mutates its argument in place
        (e.g. ``lst[i] *= 2``) cannot corrupt the caller's examples and skew
        the scoring of every later candidate.

        Returns:
            ``(success, output, error)``: ``success`` is True when the
            program ran to completion; otherwise ``output`` is None and
            ``error`` describes the failure.
        """
        try:
            argument = copy.deepcopy(test_input)
        except Exception:
            argument = test_input  # not deep-copyable; pass through unchanged

        try:
            namespace: Dict[str, Any] = {}
            exec(compile(program, self._PROGRAM_FILENAME, "exec"), namespace)

            func = self._find_entry_point(namespace)
            if func is None:
                return False, None, "No callable function found"

            output = func(argument)
            return True, output, None

        except (Exception, SystemExit) as e:
            # SystemExit: generated code sometimes calls sys.exit()/exit();
            # that must fail this attempt, not terminate the evolution run.
            # The type name is kept because str(e) alone can be empty or
            # cryptic (e.g. KeyError('x') -> "'x'").
            return False, None, f"{type(e).__name__}: {e}"

    def _find_entry_point(self, namespace: Dict[str, Any]):
        """Choose the function to call from an exec-ed program's namespace.

        Priority: a conventional entry-point name (solve, transform, main,
        solution); then the first public function the program itself
        defines; finally any other public callable. The middle step stops
        imports such as ``from typing import List`` (``List`` is callable)
        from being picked ahead of the program's own function.
        """
        for name in ('solve', 'transform', 'main', 'solution'):
            candidate = namespace.get(name)
            if callable(candidate):
                return candidate

        public = [obj for name, obj in namespace.items()
                  if callable(obj) and not name.startswith('_')]
        for obj in public:
            code = getattr(obj, '__code__', None)
            if code is not None and code.co_filename == self._PROGRAM_FILENAME:
                return obj
        return public[0] if public else None

    @staticmethod
    def _outputs_match(output: Any, expected: Any) -> bool:
        """Compare a program's output with the expected value, never raising.

        ``==`` on e.g. numpy arrays yields an array whose truth value is
        ambiguous, and user-defined ``__eq__`` may raise; either way the
        example counts as failed instead of aborting ``evolve``.
        """
        try:
            return bool(output == expected)
        except Exception:
            return False

    # ── REFINE ──────────────────────────────────────────────────────

    def refine_program(self, program: str, task: str, error: str = None,
                       actual_output: Any = None, expected_output: Any = None,
                       test_input: Any = None) -> str:
        """
        REFINE PHASE: Improve a program based on feedback.

        Args:
            program: Current program source.
            task: Task description.
            error: Error message from running the program, if any.
            actual_output: What the program returned (omitted if None).
            expected_output: What it should have returned (omitted if None).
            test_input: The input that exposed the problem (omitted if None).

        Returns:
            The refined program, or *program* unchanged if the model call
            fails or yields no extractable code.
        """
        feedback_parts = []
        if test_input is not None:
            feedback_parts.append(f"Input: {test_input}")
        if error:
            feedback_parts.append(f"Error: {error}")
        if actual_output is not None:
            feedback_parts.append(f"Got: {actual_output}")
        if expected_output is not None:
            feedback_parts.append(f"Expected: {expected_output}")

        feedback = "\n".join(feedback_parts) if feedback_parts else "Program works but may be improvable"

        prompt = f"""Task: {task}

Current program:
```python
{program}
```

Feedback:
{feedback}

Please fix/improve the program. Output ONLY the corrected Python code."""

        try:
            response = self.client.chat.completions.create(
                model=self.model,
                messages=[
                    {"role": "system", "content": "You are an expert Python debugger. Fix the code based on the feedback. Output ONLY Python code."},
                    {"role": "user", "content": prompt}
                ],
                temperature=0.3,  # Lower for focused refinement
                max_tokens=2000
            )
            refined = self._extract_code(response.choices[0].message.content)
            return refined if refined else program
        except Exception as e:
            print(f"[REFINE] Error: {e}")
            return program

    # ── LEARN ───────────────────────────────────────────────────────

    def hindsight_relabel(self, program: str, original_task: str,
                          actual_input: Any, actual_output: Any) -> Optional[Dict]:
        """
        HINDSIGHT LEARNING: A failed program is correct for SOME task.
        Create a synthetic task that this program DOES solve and store it.

        Returns:
            The synthetic task dict, or None when the input/output cannot be
            JSON-serialized. Such pairs are skipped rather than stringified
            because ``get_training_data`` must round-trip them through
            ``json.loads``.
        """
        # This program failed the original task, but it produced SOME output.
        # That output is the correct answer for a synthetic task.
        synthetic_task = {
            "original_task": original_task,
            "program": program,
            "synthetic_input": actual_input,
            "synthetic_output": actual_output,
            "description": f"Transform input to produce: {actual_output}"
        }

        try:
            task_json = json.dumps(synthetic_task)
            input_json = json.dumps(actual_input)
            output_json = json.dumps(actual_output)
        except (TypeError, ValueError) as e:
            print(f"[HINDSIGHT] Skipping non-JSON-serializable example: {e}")
            return None

        # Store for training
        with self._connect() as conn:
            with conn:
                conn.execute('''INSERT INTO hindsight_relabels
                                (original_task, failed_program, synthetic_task, synthetic_input, synthetic_output, timestamp)
                                VALUES (?, ?, ?, ?, ?, ?)''',
                             (original_task, program, task_json, input_json, output_json,
                              datetime.now().isoformat()))

        return synthetic_task

    # ── FULL LOOP ───────────────────────────────────────────────────

    def evolve(self, task: str, examples: List[Dict],
               max_iterations: int = 5, samples_per_iteration: int = 3) -> Optional[str]:
        """
        FULL SOAR LOOP: Evolve a solution through sample → refine → learn.

        Each iteration refines the best program so far, using the first
        example it fails as concrete feedback, and samples fresh candidates.
        Every new candidate is scored on all *examples*; wrong outputs are
        hindsight-relabelled and each attempt is recorded in the database.

        Args:
            task: Task description.
            examples: Dicts with ``'input'`` and ``'output'`` keys.
            max_iterations: Number of sample/refine rounds.
            samples_per_iteration: Candidates per round (including the
                refined one once a best program exists).

        Returns:
            The first program that passes every example; otherwise the
            highest-scoring program, or None if no program passed any example.
        """
        print(f"[SOAR] Evolving solution for: {task[:50]}...")

        best_program: Optional[str] = None
        best_score = 0
        best_feedback: Dict[str, Any] = {}  # refine_program kwargs for best_program
        seen = set()                         # sources already scored in this run

        for iteration in range(max_iterations):
            print(f"[SOAR] Iteration {iteration + 1}/{max_iterations}")

            if best_program:
                # REFINE the incumbent with real feedback, then top up with
                # fresh SAMPLEs for diversity.
                programs = [self.refine_program(best_program, task, **best_feedback)]
                programs.extend(self.sample_program(task, examples, samples_per_iteration - 1))
            else:
                programs = self.sample_program(task, examples, samples_per_iteration)

            for program in programs:
                # Identical source should score identically. Re-running it
                # would only log duplicate training rows; this happens whenever
                # refine_program falls back to returning its input.
                if program in seen:
                    continue
                seen.add(program)

                score, failure, outputs, errors = self._score_program(program, task, examples)
                all_correct = failure is None

                self._record_attempt(task, program, all_correct, examples,
                                     outputs=outputs, errors=errors)

                if score > best_score:
                    best_score = score
                    best_program = program
                    best_feedback = failure or {}
                    print(f"[SOAR] New best: {score}/{len(examples)} examples correct")

                if all_correct:
                    print(f"[SOAR] ✓ Solution found in iteration {iteration + 1}")
                    return program

        print(f"[SOAR] Best achieved: {best_score}/{len(examples)} examples")
        return best_program

    def _score_program(self, program: str, task: str, examples: List[Dict]):
        """Run *program* on every example and hindsight-relabel its wrong outputs.

        Returns:
            ``(score, first_failure, outputs, errors)``. ``score`` is the
            number of examples passed. ``first_failure`` is None when all
            passed; otherwise it holds refine_program keyword arguments
            describing the first failing example. ``outputs`` and ``errors``
            are per-example lists.
        """
        score = 0
        first_failure: Optional[Dict[str, Any]] = None
        outputs: List[Any] = []
        errors: List[Optional[str]] = []

        for ex in examples:
            success, output, error = self.execute_program(program, ex['input'])
            outputs.append(output)
            errors.append(error)

            if success and self._outputs_match(output, ex['output']):
                score += 1
                continue

            if first_failure is None:
                feedback_error = error
                if success and output is None:
                    # refine_program omits a None "Got:", so say it explicitly.
                    feedback_error = "Program returned None"
                first_failure = {
                    "error": feedback_error,
                    "actual_output": output,
                    "expected_output": ex['output'],
                    "test_input": ex['input'],
                }

            # Hindsight relabel - this program solves SOME task
            if success and output is not None:
                self.hindsight_relabel(program, task, ex['input'], output)

        return score, first_failure, outputs, errors

    # ── prompt / parsing helpers ────────────────────────────────────

    def _build_sample_prompt(self, task: str, examples: Optional[List[Dict]] = None) -> str:
        """Build prompt for program sampling"""
        prompt = f"Task: {task}\n\n"

        if examples:
            prompt += "Examples:\n"
            for i, ex in enumerate(examples):
                prompt += f"Input {i+1}: {ex['input']}\n"
                prompt += f"Output {i+1}: {ex['output']}\n\n"

        prompt += "Write a Python function called 'solve' that takes an input and returns the correct output."
        return prompt

    @staticmethod
    def _skip_fence_info(text: str, pos: int) -> int:
        """Return the index past a code fence's info string, if there is one.

        When the rest of the opening-fence line (from *pos*) is a bare tag
        such as ``py``, ``Python`` or the ``3`` of ``python3``, the index
        just after that line is returned. Otherwise *pos* is returned
        unchanged. Without this, ``py`` would become the first line of the
        extracted code and raise NameError on exec.
        """
        newline = text.find("\n", pos)
        if newline == -1:
            return pos
        info = text[pos:newline].strip()
        if all(ch.isalnum() or ch in "_.+-" for ch in info):
            return newline + 1
        return pos

    def _extract_code(self, text: Optional[str]) -> Optional[str]:
        """Extract Python code from an LLM response.

        A fenced ```` ```python ```` block is preferred, then the first
        generic ```` ``` ```` block. Failing that, the whole response is
        returned if it parses as Python. Returns None when no code is found,
        including when *text* is None or empty (the API may return no
        content).
        """
        if not text:
            return None

        tagged = text.find("```python")
        if tagged != -1:
            body_start = tagged + len("```python")
        else:
            plain = text.find("```")
            body_start = -1 if plain == -1 else plain + len("```")

        if body_start != -1:
            body_start = self._skip_fence_info(text, body_start)
            end = text.find("```", body_start)
            if end > body_start:
                return text[body_start:end].strip()

        # Check if entire response is code
        try:
            ast.parse(text)
        except (SyntaxError, ValueError):  # ValueError: null bytes on older Pythons
            return None
        return text.strip()

    def _record_attempt(self, task: str, program: str, success: bool, examples: List[Dict],
                        outputs: Optional[List[Any]] = None,
                        errors: Optional[List[Optional[str]]] = None):
        """Record one synthesis attempt in ``synthesis_attempts``.

        Args:
            task: Task description; its first 50 characters double as task_id.
            program: Program source.
            success: Whether the program passed every example.
            examples: Examples it was scored on; their expected outputs go
                into the ``expected`` column.
            outputs: Optional per-example outputs (``output`` column).
            errors: Optional per-example error messages (``error`` column).
        """
        row = (
            task[:50],
            task,
            program,
            int(success),
            None if outputs is None else self._to_json(outputs),
            self._to_json([e['output'] for e in examples]),
            None if errors is None else self._to_json(errors),
            datetime.now().isoformat(),
        )
        with self._connect() as conn:
            with conn:
                conn.execute('''INSERT INTO synthesis_attempts
                                (task_id, task_description, program, success, output, expected, error, timestamp)
                                VALUES (?, ?, ?, ?, ?, ?, ?, ?)''',
                             row)

    def get_training_data(self) -> List[Dict]:
        """
        Get all data for training (successes + hindsight relabels).
        This is what makes SOAR self-improving!

        Returns:
            ``{"type": "success", "task", "program"}`` entries for every
            successful attempt, followed by ``{"type": "hindsight", "task",
            "program", "input", "output"}`` entries for every relabel.
        """
        with self._connect() as conn:
            successes = conn.execute(
                "SELECT task_description, program FROM synthesis_attempts WHERE success = 1"
            ).fetchall()
            relabels = conn.execute(
                "SELECT synthetic_task, failed_program FROM hindsight_relabels"
            ).fetchall()

        training_data = [
            {"type": "success", "task": task, "program": program}
            for task, program in successes
        ]

        # Hindsight relabeled (failed programs that solve synthetic tasks)
        for task_json, program in relabels:
            task = json.loads(task_json)
            training_data.append({
                "type": "hindsight",
                "task": task['description'],
                "program": program,
                "input": task['synthetic_input'],
                "output": task['synthetic_output']
            })

        return training_data


# ═══════════════════════════════════════════════════════════════
# QUICK TEST
# ═══════════════════════════════════════════════════════════════

if __name__ == "__main__":
    rule = "=" * 60
    soar = EdenSOAR()

    # Smoke-test task: element-wise doubling of a list of numbers.
    task = "Double each number in the input list"
    examples = [
        {"input": [1, 2, 3], "output": [2, 4, 6]},
        {"input": [5, 10], "output": [10, 20]},
        {"input": [0, -1, 7], "output": [0, -2, 14]},
    ]

    print(rule)
    print("EDEN-SOAR TEST: Self-Improving Program Synthesis")
    print(rule)

    solution = soar.evolve(task, examples, max_iterations=3, samples_per_iteration=2)

    if solution:
        print("\n" + rule)
        print("FINAL SOLUTION:")
        print(rule)
        print(solution)

        # Re-run the winning program against every example.
        print("\nVerification:")
        for example in examples:
            expected = example['output']
            _, produced, _ = soar.execute_program(solution, example['input'])
            mark = "✓" if produced == expected else "✗"
            print(f"  {mark} {example['input']} → {produced} (expected {expected})")

    # Summarise the data gathered for self-improvement.
    print("\n" + rule)
    print("TRAINING DATA COLLECTED (for self-improvement):")
    print(rule)
    training = soar.get_training_data()
    print(f"Total samples: {len(training)}")
    for record in training[:5]:
        print(f"  [{record['type']}] {record['task'][:50]}...")
