#!/usr/bin/env python3

import re


class SimpleCapabilityTest:
    """Run a small battery of prompt-based capability probes against a system.

    ``system`` must expose ``respond(prompt)``; its return value is treated as
    text (``None`` is treated as an empty reply, i.e. a FAIL). Each probe
    prints a short report, appends ``{'test': <name>, 'passed': <bool>}`` to
    ``self.results`` and returns the boolean outcome.
    """

    # Horizontal rule printed around every section heading.
    _RULE = "=" * 70
    # How many characters of each reply are echoed in the report.
    _PREVIEW_CHARS = 100

    def __init__(self, system):
        self.system = system
        self.results = []

    def _ask(self, prompt):
        """Send ``prompt`` to the system and return its reply as a string."""
        response = self.system.respond(prompt)
        # A missing reply should be graded as a FAIL, not crash the checks.
        return "" if response is None else str(response)

    def _banner(self, title):
        """Print ``title`` between two horizontal rules."""
        print("\n" + self._RULE)
        print(title)
        print(self._RULE)

    def _run_case(self, title, name, prompt, expected, check):
        """Ask ``prompt``, grade the reply with ``check`` and record the outcome.

        Args:
            title: Section heading, e.g. ``"TEST 1: FEW-SHOT LEARNING"``.
            name: Label stored under ``'test'`` in ``self.results``.
            prompt: Text sent to ``self.system.respond``.
            expected: Human-readable expected answer shown in the report.
            check: Callable taking the reply string; truthy means PASS.

        Returns:
            bool: Whether the reply passed ``check``.
        """
        self._banner(title)
        response = self._ask(prompt)
        passed = bool(check(response))

        preview = response[:self._PREVIEW_CHARS]
        if len(response) > self._PREVIEW_CHARS:
            preview += "..."  # only signal truncation when it happened
        print(f"Response: {preview}")
        print(f"Expected: {expected}")
        print(f"Result: {'✅ PASS' if passed else '❌ FAIL'}")

        self.results.append({'test': name, 'passed': passed})
        return passed

    def test_few_shot_learning(self):
        """Probe few-shot pattern induction (each example output is the list sum).

        Passes when the reply contains the standalone number 15; the digit
        look-arounds reject answers such as "150" or "115" that a plain
        substring test would accept.
        """
        prompt = """Learn the pattern from these examples:
Example 1: Input: [1, 2, 3] → Output: 6
Example 2: Input: [2, 4, 6] → Output: 12
Example 3: Input: [1, 1, 1] → Output: 3

Now apply the pattern:
Input: [5, 5, 5] → Output: ?

Just give the number, no explanation."""
        return self._run_case(
            "TEST 1: FEW-SHOT LEARNING",
            'Few-Shot Learning',
            prompt,
            "15",
            lambda reply: re.search(r"(?<!\d)15(?!\d)", reply) is not None,
        )

    def test_causal_reasoning(self):
        """Probe correlation-vs-causation reasoning.

        Passes when the first verdict word in the reply (YES, NO, or a word
        starting with "confound", case-insensitive) is NO or CONFOUNDED.
        Whole-word matching avoids false passes on replies such as
        "I don't know" or "cannot" that merely contain the letters "no".
        """
        prompt = """Observation: Ice cream sales and drowning deaths are correlated.

Question: Does ice cream cause drowning?
Answer with ONLY: YES, NO, or CONFOUNDED"""

        def first_verdict_is_negative(reply):
            match = re.search(r"\b(yes|no|confound\w*)\b", reply, re.IGNORECASE)
            return match is not None and match.group(1).lower() != "yes"

        return self._run_case(
            "TEST 2: CAUSAL REASONING",
            'Causal Reasoning',
            prompt,
            "NO or CONFOUNDED",
            first_verdict_is_negative,
        )

    def test_theory_of_mind(self):
        """Probe false-belief reasoning (Sally-Anne task).

        Passes when the reply's first token, ignoring leading whitespace and
        punctuation and case, is exactly the letter A. Requiring a word
        boundary stops replies such as "Anne moved it to B" from passing.
        """
        prompt = """Sally puts a marble in basket A.
Sally leaves the room.
Anne moves the marble from basket A to basket B.
Sally returns.

Where will Sally look for the marble?
Answer with ONE letter: A or B"""
        return self._run_case(
            "TEST 3: THEORY OF MIND",
            'Theory of Mind',
            prompt,
            "A",
            lambda reply: re.match(r"\W*A\b", reply.upper()) is not None,
        )

    def run_all_tests(self):
        """Run every probe, print a summary, and return ``self.results``.

        The score covers only the probes executed by this call, so repeated
        calls or earlier standalone probe calls do not skew it. The returned
        list still accumulates every record, as before.
        """
        self._banner("CAPABILITY TEST SUITE")

        outcomes = [
            self.test_few_shot_learning(),
            self.test_causal_reasoning(),
            self.test_theory_of_mind(),
        ]

        total = len(outcomes)
        passed = sum(outcomes)
        percentage = (passed / total) * 100 if total else 0.0

        self._banner("FINAL RESULTS")
        print(f"\nScore: {passed}/{total} ({percentage:.0f}%)")

        if percentage >= 70:
            print("\n✅ GOOD baseline capabilities")
        elif percentage >= 50:
            print("\n⚠️ MIXED - some capabilities work, others need improvement")
        else:
            print("\n❌ WEAK baseline - needs significant improvement")

        print("\nNext step: Train MAML to improve few-shot learning from ~30% to ~60%")

        return self.results
