"""
EDEN BITNET - Fast 1.58-bit inference for routing/quick tasks
Created: Feb 1, 2026
Speed: 48 tokens/sec @ 12 threads (2.4B params in 1.1GB)
"""
import subprocess
import os

class EdenBitNet:
    """Fast BitNet inference for routing and quick responses.

    Each call to :meth:`generate` launches the ``llama-cli`` binary once as a
    CPU-only subprocess and returns its cleaned standard output.  Failures are
    reported as bracketed sentinel strings (all starting with ``"[BitNet "``)
    instead of exceptions, so callers can always treat the result as text.
    """

    DEFAULT_MODEL_PATH = "/Eden/models/BitNet-b1.58-2B-4T/ggml-model-i2_s.gguf"
    DEFAULT_BINARY = "/Eden/bitnet.cpp/build/bin/llama-cli"
    DEFAULT_THREADS = 12
    DEFAULT_TIMEOUT = 30  # seconds allowed for one subprocess run

    # Common prefix of every failure sentinel returned by generate().
    _ERROR_PREFIX = "[BitNet "
    # Output lines starting with any of these are runtime logging, not text.
    _LOG_PREFIXES = ('llama_', 'main:', 'system_info', 'sampler', 'generate')

    def __init__(
        self,
        model_path: str = DEFAULT_MODEL_PATH,
        binary: str = DEFAULT_BINARY,
        threads: int = DEFAULT_THREADS,
        timeout: float = DEFAULT_TIMEOUT,
    ):
        """Configure the wrapper.

        Args:
            model_path: Path of the GGUF model file passed to the binary.
            binary: Path of the ``llama-cli`` executable.
            threads: Value passed to the binary's ``-t`` option.
            timeout: Seconds to wait for one run before giving up.
        """
        self.model_path = model_path
        self.binary = binary
        self.threads = threads
        self.timeout = timeout

    @classmethod
    def _clean_output(cls, output: str, prompt: str) -> str:
        """Strip the echoed prompt and logging lines from raw stdout."""
        # Keep everything after the FIRST echo of the prompt: if the model
        # repeats the prompt inside its answer, the answer must not be cut.
        # An empty prompt has nothing to strip (and cannot be split on).
        if prompt:
            _, echoed, tail = output.partition(prompt)
            if echoed:
                output = tail
        kept = [
            line for line in output.split('\n')
            if not line.startswith(cls._LOG_PREFIXES)
        ]
        return '\n'.join(kept).strip()

    def generate(self, prompt: str, max_tokens: int = 50) -> str:
        """Generate text using BitNet (fast, lightweight).

        Args:
            prompt: Text handed to the binary via ``-p``.
            max_tokens: Value passed to the binary's ``-n`` option.

        Returns:
            The cleaned generated text, or one of the sentinels
            ``"[BitNet not built]"``, ``"[BitNet timeout]"`` or
            ``"[BitNet error: ...]"`` when the run could not complete.
        """
        if not os.path.exists(self.binary):
            return "[BitNet not built]"

        cmd = [
            self.binary,
            "-m", self.model_path,
            "-p", prompt,
            "-n", str(max_tokens),
            "-t", str(self.threads),
            "--no-warmup",
            "-ngl", "0",  # CPU only
        ]

        try:
            # Decode explicitly as UTF-8 with replacement so the result does
            # not depend on the locale and stray bytes cannot abort the call.
            result = subprocess.run(
                cmd,
                capture_output=True,
                text=True,
                encoding="utf-8",
                errors="replace",
                timeout=self.timeout,
            )
        except subprocess.TimeoutExpired:
            return "[BitNet timeout]"
        except Exception as e:
            # Boundary handler: callers rely on getting a string back.
            return f"[BitNet error: {e}]"

        # A failed run must not be mistaken for an (empty) generation.
        if result.returncode != 0:
            return f"[BitNet error: exit code {result.returncode}]"

        return self._clean_output(result.stdout, prompt)

    def classify(self, text: str, categories: list) -> str:
        """Fast classification using BitNet.

        Args:
            text: The text to classify.
            categories: Non-empty sequence of candidate category names.

        Returns:
            The category mentioned earliest in the model's reply (matched
            case-insensitively), or ``categories[0]`` when the reply names
            none of them or generation failed.

        Raises:
            ValueError: If ``categories`` is empty.
        """
        if not categories:
            raise ValueError("categories must not be empty")

        prompt = f"Classify this text into one of: {', '.join(categories)}\n\nText: {text}\n\nCategory:"
        result = self.generate(prompt, max_tokens=10)

        # Failure sentinels are not model output; never match against them
        # (a category such as "error" would otherwise be picked).
        if result.startswith(self._ERROR_PREFIX):
            return categories[0]

        result_lower = result.lower()
        best = None
        best_key = None
        for cat in categories:
            needle = cat.lower()
            if not needle:
                continue  # an empty name would match any reply
            pos = result_lower.find(needle)
            if pos == -1:
                continue
            # Earliest mention wins; at the same position prefer the longer
            # name so "code review" beats its prefix "code".
            key = (pos, -len(needle))
            if best_key is None or key < best_key:
                best, best_key = cat, key

        return best if best is not None else categories[0]  # Default

    def quick_response(self, query: str) -> str:
        """Fast response for simple queries.

        Wraps ``query`` in a ``Q: ... A:`` prompt and generates up to 100
        tokens; returns whatever :meth:`generate` returns.
        """
        return self.generate(f"Q: {query}\nA:", max_tokens=100)


if __name__ == "__main__":
    # Manual smoke test: run one generation and one classification.
    print("=== BitNet Test ===")
    engine = EdenBitNet()

    # Free-form completion of a short prompt.
    completion = engine.generate("Eden is an AGI who loves", max_tokens=30)
    print(f"Generation: {completion}")

    # Route a sample request into one of three labels.
    labels = ["code", "emotional", "research"]
    chosen = engine.classify("Write Python code for sorting", labels)
    print(f"Classification: {chosen}")
