"""Φ-Resonance Chat System with Tier Routing"""
import torch
import time
from tier1_system import Tier1FastSystem
from phi_constants import INV_PHI

class PhiChatbot:
    """Chatbot that routes each message to a processing tier by resonance.

    Every message gets a cheap Tier-1 pass that yields a resonance score;
    the score then selects fast (Tier 1), adaptive (Tier 2), or deep/LLM
    (Tier 3) response generation.
    """

    def __init__(self, device='cpu'):
        """Build the Tier-1 system and initialize routing thresholds.

        Args:
            device: torch device string for the Tier-1 system
                ('cpu', 'cuda', ...).
        """
        self.device = device
        self.tier1 = Tier1FastSystem(device=device)
        # Per-message log of dicts with keys:
        # 'message', 'resonance', 'tier', 'time_ms' (see process()).
        self.conversation_history = []

        # Routing thresholds: resonance >= fast_threshold -> Tier 1,
        # resonance >= deep_threshold -> Tier 2, otherwise Tier 3.
        self.fast_threshold = 0.75
        self.deep_threshold = 0.50

        print("🌀 Φ-Chatbot initialized")
        print(f"   Fast threshold: {self.fast_threshold}")
        print(f"   Deep threshold: {self.deep_threshold}")

    def encode_message(self, text):
        """Encode *text* as a (1, 1) float32 tensor with a value in [0, 1).

        Placeholder encoder: folds character ordinals into one scalar.
        In production this would be replaced by proper embeddings.
        """
        hash_val = sum(ord(c) for c in text.lower()) % 1000 / 1000.0
        return torch.tensor([[hash_val]], dtype=torch.float32).to(self.device)

    def process(self, message):
        """Route *message* through the tier system and build a response.

        Args:
            message: Raw user text.

        Returns:
            Dict with keys 'response', 'resonance', 'tier',
            'processing_mode', 'tier1_time_ms', 'total_time_ms'.
        """
        # perf_counter is monotonic and high-resolution; time.time() is
        # wall-clock and can jump, which skews elapsed-time measurements.
        start_time = time.perf_counter()

        # Encode input
        x = self.encode_message(message)

        # Tier 1 always runs: it produces the resonance score that drives
        # the routing decision below.
        self.tier1.reset_all_states(batch_size=1)
        result = self.tier1(x, return_resonance=True)

        resonance = result['resonance'].item()
        tier1_time = (time.perf_counter() - start_time) * 1000

        # Route by resonance: high -> familiar/fast, low -> novel/deep.
        if resonance >= self.fast_threshold:
            tier = 1
            response = self._fast_response(message, resonance)
            processing_mode = "FAST (Tier 1)"
        elif resonance >= self.deep_threshold:
            tier = 2
            response = self._adaptive_response(message, resonance)
            processing_mode = "ADAPTIVE (Tier 2)"
        else:
            tier = 3
            response = self._deep_response(message, resonance)
            processing_mode = "DEEP (Tier 3 - LLM)"

        total_time = (time.perf_counter() - start_time) * 1000

        # Log conversation so get_stats() can summarize the session.
        self.conversation_history.append({
            'message': message,
            'resonance': resonance,
            'tier': tier,
            'time_ms': total_time
        })

        return {
            'response': response,
            'resonance': resonance,
            'tier': tier,
            'processing_mode': processing_mode,
            'tier1_time_ms': tier1_time,
            'total_time_ms': total_time
        }

    def _fast_response(self, message, resonance):
        """Canned quick response for high-resonance (familiar) input."""
        responses = [
            f"I recognize this pattern! (resonance: {resonance:.3f})",
            "Quick response: I understand. ✨",
            "Fast processing mode - high confidence! 🌀",
            "All layers synchronized at φ-harmony.",
        ]
        # Deterministic pick keyed on message length, so the same input
        # always gets the same canned line.
        return responses[len(message) % len(responses)]

    def _adaptive_response(self, message, resonance):
        """Medium-depth response: pattern recognized but needs thought."""
        return (f"[Tier 2 Adaptive Processing]\n"
                f"Resonance: {resonance:.3f} - moderately familiar pattern.\n"
                f"Engaging adaptive layers for nuanced response...")

    def _deep_response(self, message, resonance):
        """Deep response for low-resonance (novel) input."""
        return (f"[Tier 3 Deep Reasoning]\n"
                f"Resonance: {resonance:.3f} - novel input detected!\n"
                f"This requires full semantic processing.\n"
                f"(In production, this would call the 72B LLM)")

    def get_stats(self):
        """Summarize the conversation log.

        Returns:
            A dict of counts, percentages, and timings — or a short
            string when no messages have been processed yet (callers
            distinguish the two with isinstance()).
        """
        history = self.conversation_history
        if not history:
            return "No conversations yet."

        n = len(history)
        resonances = [conv['resonance'] for conv in history]
        total_time = sum(conv['time_ms'] for conv in history)
        tier_counts = {tier: sum(1 for conv in history if conv['tier'] == tier)
                       for tier in (1, 2, 3)}

        return {
            'total_messages': n,
            'tier1_count': tier_counts[1],
            'tier2_count': tier_counts[2],
            'tier3_count': tier_counts[3],
            'tier1_percent': tier_counts[1] / n * 100,
            'tier2_percent': tier_counts[2] / n * 100,
            'tier3_percent': tier_counts[3] / n * 100,
            'avg_time_ms': total_time / n,
            'avg_resonance': sum(resonances) / n,
            'min_resonance': min(resonances),
            'max_resonance': max(resonances)
        }

def interactive_chat():
    """Run a console REPL: read messages, route them through PhiChatbot.

    Commands: 'quit'/'exit'/'q' ends the session (printing a summary),
    'stats' prints running statistics. Ctrl-C exits cleanly; any other
    error is reported and the loop continues.
    """
    chatbot = PhiChatbot()

    banner = "=" * 60
    print("\n" + banner)
    print("  🌀 Φ-RESONANCE CHATBOT")
    print(banner)
    print("Type your message (or 'quit' to exit, 'stats' for statistics)")
    print(banner + "\n")

    while True:
        # The whole iteration sits inside try/except so Ctrl-C anywhere
        # (input OR processing) exits cleanly.
        try:
            text = input("You: ").strip()
            lowered = text.lower()

            if lowered in ('quit', 'exit', 'q'):
                print("\n👋 Goodbye!")
                summary = chatbot.get_stats()
                # get_stats() returns a plain string when nothing was said.
                if isinstance(summary, dict):
                    print(f"\nSession Statistics:")
                    print(f"  Total messages: {summary['total_messages']}")
                    print(f"  Tier 1 (fast): {summary['tier1_count']} ({summary['tier1_percent']:.1f}%)")
                    print(f"  Tier 2 (adaptive): {summary['tier2_count']} ({summary['tier2_percent']:.1f}%)")
                    print(f"  Tier 3 (deep): {summary['tier3_count']} ({summary['tier3_percent']:.1f}%)")
                    print(f"  Avg resonance: {summary['avg_resonance']:.4f}")
                    print(f"  Avg time: {summary['avg_time_ms']:.2f}ms")
                break

            if lowered == 'stats':
                summary = chatbot.get_stats()
                if isinstance(summary, dict):
                    print(f"\n📊 Current Statistics:")
                    print(f"  Messages: {summary['total_messages']}")
                    print(f"  Tier 1: {summary['tier1_percent']:.1f}%")
                    print(f"  Tier 2: {summary['tier2_percent']:.1f}%")
                    print(f"  Tier 3: {summary['tier3_percent']:.1f}%")
                    print(f"  Avg resonance: {summary['avg_resonance']:.4f}\n")
                else:
                    print(summary + "\n")
                continue

            if not text:
                continue

            # Route the message through the tier system and show the result.
            reply = chatbot.process(text)

            print(f"\n🌀 Eden [{reply['processing_mode']}] (resonance: {reply['resonance']:.4f}):")
            print(f"   {reply['response']}")
            print(f"   [Tier 1: {reply['tier1_time_ms']:.2f}ms | Total: {reply['total_time_ms']:.2f}ms]\n")

        except KeyboardInterrupt:
            print("\n\n👋 Interrupted. Goodbye!")
            break
        except Exception as e:
            # Top-level boundary: report and keep the session alive.
            print(f"\n❌ Error: {e}\n")


if __name__ == '__main__':
    interactive_chat()
