"""
Relational Representation Engine
Converts flat embeddings into structured knowledge graphs
"""

import networkx as nx
import numpy as np
from typing import List, Dict, Optional
from sentence_transformers import SentenceTransformer
import spacy
import yaml
import logging
import os
from pathlib import Path

logger = logging.getLogger(__name__)

class RelationalEncoder:
    """Builds typed relational graphs from text/embeddings.

    Named entities and noun chunks become graph nodes carrying a
    sentence-transformer embedding; dependency-parse heuristics add typed
    edges (CAUSES, IS_A). The graph is periodically persisted to GraphML
    under /Eden/MEMORY/graphs and reloaded on construction.
    """

    def __init__(self, config_path: str = "/Eden/CONFIG/phi_fractal_config.yaml"):
        """Load config, NLP/embedding models, and any previously saved graph.

        Args:
            config_path: YAML file with a 'relational_encoder' section
                providing relation_types, confidence_threshold,
                max_graph_nodes, save_interval and embedding_model.

        Raises:
            OSError/KeyError: if the config file or required keys are missing.
        """
        with open(config_path) as f:
            config = yaml.safe_load(f)

        self.config = config['relational_encoder']
        self.relation_types = self.config['relation_types']
        self.confidence_threshold = self.config['confidence_threshold']
        self.max_nodes = self.config['max_graph_nodes']
        self.save_interval = self.config['save_interval']

        # spaCy is optional at runtime: encode_text degrades to a no-op
        # (returns the graph unchanged) when the model is unavailable.
        try:
            self.nlp = spacy.load("en_core_web_sm")
            logger.info("✓ spaCy model loaded")
        except OSError:
            logger.error("spaCy model not found. Run: python -m spacy download en_core_web_sm")
            self.nlp = None

        model_name = self.config['embedding_model']
        self.embedding_model = SentenceTransformer(model_name)
        logger.info(f"✓ Embedding model loaded: {model_name}")

        self.graph = nx.DiGraph()
        self.update_count = 0  # number of completed encode_text() passes

        self.graph_path = Path("/Eden/MEMORY/graphs")
        self.graph_path.mkdir(parents=True, exist_ok=True)

        self.load_graph()

        logger.info(f"RelationalEncoder initialized: {len(self.graph.nodes)} nodes")

    def encode_text(self, text: str, session_id: str) -> nx.DiGraph:
        """Parse text, merge its entities/relations into the graph.

        Args:
            text: raw text to analyze; blank input is a no-op.
            session_id: identifier recorded on nodes/edges for provenance.

        Returns:
            The shared internal graph (also mutated in place).
        """
        if not self.nlp:
            logger.error("spaCy not available")
            return self.graph

        if not text or len(text.strip()) == 0:
            return self.graph

        doc = self.nlp(text)
        entities = self._extract_entities(doc)

        if not entities:
            logger.debug(f"No entities found in: {text[:50]}")
            return self.graph

        relations = self._extract_relations(doc, entities)

        for entity in entities:
            self._add_node(entity, session_id)

        for rel in relations:
            self._add_edge(rel, session_id)

        self.update_count += 1
        # Bug fix: guard save_interval <= 0 from config, which previously
        # raised ZeroDivisionError on the modulo.
        if self.save_interval and self.update_count % self.save_interval == 0:
            self.save_graph()
            logger.info(f"Graph auto-saved at update {self.update_count}")

        return self.graph

    def _extract_entities(self, doc) -> List[Dict]:
        """Extract entities from a spaCy doc.

        Named entities are collected first, then noun chunks (as 'CONCEPT')
        whose head is a noun; duplicates are skipped case-insensitively.
        Each entity dict carries text, label, char offsets and an embedding.
        """
        entities = []
        seen = set()

        for ent in doc.ents:
            if ent.text.lower() not in seen:
                entities.append({
                    'text': ent.text,
                    'label': ent.label_,
                    'start': ent.start_char,
                    'end': ent.end_char,
                    'embedding': self.embedding_model.encode(ent.text).tolist()
                })
                seen.add(ent.text.lower())

        for chunk in doc.noun_chunks:
            if chunk.text.lower() not in seen and chunk.root.pos_ in ['NOUN', 'PROPN']:
                entities.append({
                    'text': chunk.text,
                    'label': 'CONCEPT',
                    'start': chunk.start_char,
                    'end': chunk.end_char,
                    'embedding': self.embedding_model.encode(chunk.text).tolist()
                })
                seen.add(chunk.text.lower())

        return entities

    def _extract_relations(self, doc, entities: List[Dict]) -> List[Dict]:
        """Extract typed relations between entities.

        Heuristics: causal verbs (cause/lead/result/...) with nsubj+dobj
        children yield CAUSES edges; an auxiliary 'be' with nsubj+attr
        children yields IS_A edges. Confidences are fixed per pattern.
        """
        relations = []

        for token in doc:
            if token.lemma_ in ['cause', 'lead', 'result', 'produce', 'create', 'trigger']:
                subj = self._find_entity_by_dep(token, entities, 'nsubj')
                obj = self._find_entity_by_dep(token, entities, 'dobj')
                if subj and obj:
                    relations.append({
                        'source': subj['text'],
                        'target': obj['text'],
                        'type': 'CAUSES',
                        'confidence': 0.8
                    })

            elif token.lemma_ == 'be' and token.pos_ == 'AUX':
                subj = self._find_entity_by_dep(token, entities, 'nsubj')
                obj = self._find_entity_by_dep(token, entities, 'attr')
                if subj and obj:
                    relations.append({
                        'source': subj['text'],
                        'target': obj['text'],
                        'type': 'IS_A',
                        'confidence': 0.9
                    })

        return relations

    def _find_entity_by_dep(self, token, entities: List[Dict], dep: str) -> Optional[Dict]:
        """Return the entity containing the first child of token with the
        given dependency label, or None if no such child/entity exists."""
        for child in token.children:
            if child.dep_ == dep:
                return self._find_entity_by_token(child, entities)
        return None

    def _find_entity_by_token(self, token, entities: List[Dict]) -> Optional[Dict]:
        """Return the first entity whose char span contains token.idx."""
        for entity in entities:
            if entity['start'] <= token.idx < entity['end']:
                return entity
        return None

    def _add_node(self, entity: Dict, session_id: str):
        """Add a node for the entity, or append session provenance if the
        node (keyed by lowercased text) already exists.

        Bug fix: max_graph_nodes was loaded from config but never enforced,
        letting the graph grow without bound. New nodes are now skipped once
        the cap is reached; existing nodes are still updated.
        """
        node_id = entity['text'].lower().strip()

        if not self.graph.has_node(node_id):
            if len(self.graph) >= self.max_nodes:
                logger.debug(f"Graph at max_graph_nodes ({self.max_nodes}); skipping new node: {node_id}")
                return
            self.graph.add_node(
                node_id,
                label=entity['text'],
                entity_type=entity['label'],
                embedding=entity['embedding'],
                sessions=[session_id],
                created=self.update_count
            )
        else:
            sessions = self.graph.nodes[node_id].get('sessions', [])
            if session_id not in sessions:
                sessions.append(session_id)
                self.graph.nodes[node_id]['sessions'] = sessions

    def _add_edge(self, relation: Dict, session_id: str):
        """Add a typed edge, or reinforce an existing one.

        Edges below confidence_threshold are discarded; both endpoints must
        already be nodes; self-loops are skipped. Re-observing an existing
        edge bumps its confidence by 0.05 up to a 0.95 cap.
        NOTE(review): reinforcement does not check that the new relation's
        type matches the stored relation_type — confirm that is intended.
        """
        if relation['confidence'] < self.confidence_threshold:
            return

        source_id = relation['source'].lower().strip()
        target_id = relation['target'].lower().strip()

        if self.graph.has_node(source_id) and self.graph.has_node(target_id):
            if source_id == target_id:
                return

            if self.graph.has_edge(source_id, target_id):
                current_conf = self.graph[source_id][target_id].get('confidence', 0.5)
                new_conf = min(0.95, current_conf + 0.05)
                self.graph[source_id][target_id]['confidence'] = new_conf
            else:
                self.graph.add_edge(
                    source_id,
                    target_id,
                    relation_type=relation['type'],
                    confidence=relation['confidence'],
                    session=session_id
                )

    def find_similar_nodes(self, query: str, top_k: int = 5) -> List:
        """Return up to top_k (node_id, cosine_similarity) pairs, best first.

        Bug fixes: the query norm is now computed once outside the loop;
        zero-norm embeddings no longer cause a division by zero; nodes
        without an 'embedding' attribute are skipped instead of raising.
        """
        if len(self.graph) == 0:
            return []

        query_emb = self.embedding_model.encode(query)
        query_norm = np.linalg.norm(query_emb)
        if query_norm == 0:
            return []

        similarities = []

        for node_id in self.graph.nodes():
            emb = self.graph.nodes[node_id].get('embedding')
            if emb is None:
                continue  # node persisted without an embedding; cannot score
            node_emb = np.array(emb)
            node_norm = np.linalg.norm(node_emb)
            if node_norm == 0:
                continue  # degenerate embedding; avoid division by zero
            sim = float(np.dot(query_emb, node_emb) / (query_norm * node_norm))
            similarities.append((node_id, sim))

        similarities.sort(key=lambda x: x[1], reverse=True)
        return similarities[:top_k]

    def save_graph(self, path: Optional[str] = None):
        """Save the graph as GraphML (default: <graph_path>/current.graphml).

        GraphML cannot store list attributes, so embeddings are truncated to
        their first 10 dims as a comma string (regenerated on load) and the
        session list is truncated to 5 entries. Empty graphs are not saved.
        """
        if path is None:
            path = self.graph_path / "current.graphml"

        if len(self.graph) == 0:
            return

        graph_copy = self.graph.copy()
        for node in graph_copy.nodes():
            if 'embedding' in graph_copy.nodes[node]:
                emb = graph_copy.nodes[node]['embedding']
                graph_copy.nodes[node]['embedding_str'] = ','.join(map(str, emb[:10]))
                del graph_copy.nodes[node]['embedding']
            if 'sessions' in graph_copy.nodes[node]:
                graph_copy.nodes[node]['sessions_str'] = ','.join(graph_copy.nodes[node]['sessions'][:5])
                del graph_copy.nodes[node]['sessions']

        nx.write_graphml(graph_copy, path)
        logger.info(f"Graph saved: {len(self.graph.nodes)} nodes, {len(self.graph.edges)} edges")

    def load_graph(self, path: Optional[str] = None):
        """Load a GraphML graph (default: <graph_path>/current.graphml).

        Embeddings are re-encoded from each node's label (the saved
        'embedding_str' is lossy and only used as a marker). Missing file
        means a fresh start; any load error is logged and leaves the
        current graph untouched.
        """
        if path is None:
            path = self.graph_path / "current.graphml"

        if not os.path.exists(path):
            logger.info("No existing graph, starting fresh")
            return

        try:
            graph_loaded = nx.read_graphml(path)

            for node in graph_loaded.nodes():
                if 'embedding_str' in graph_loaded.nodes[node]:
                    label = graph_loaded.nodes[node].get('label', node)
                    graph_loaded.nodes[node]['embedding'] = self.embedding_model.encode(label).tolist()
                    del graph_loaded.nodes[node]['embedding_str']

                if 'sessions_str' in graph_loaded.nodes[node]:
                    # Bug fix: ''.split(',') yields [''] — filter empties so an
                    # empty sessions_str does not create a bogus session id.
                    raw_sessions = graph_loaded.nodes[node]['sessions_str']
                    graph_loaded.nodes[node]['sessions'] = [s for s in raw_sessions.split(',') if s]
                    del graph_loaded.nodes[node]['sessions_str']

            self.graph = graph_loaded
            logger.info(f"Graph loaded: {len(self.graph.nodes)} nodes")
        except Exception as e:
            logger.error(f"Failed to load graph: {e}")

    def get_subgraph(self, node_id: str, depth: int = 2) -> nx.DiGraph:
        """Return a copy of the subgraph within `depth` hops of node_id,
        following edges in both directions. Empty graph if node is unknown."""
        node_id = node_id.lower().strip()

        if not self.graph.has_node(node_id):
            return nx.DiGraph()

        # BFS to depth, treating the digraph as undirected for reach.
        nodes = {node_id}
        frontier = {node_id}

        for _ in range(depth):
            new_frontier = set()
            for node in frontier:
                new_frontier.update(self.graph.successors(node))
                new_frontier.update(self.graph.predecessors(node))
            frontier = new_frontier - nodes
            nodes.update(frontier)

        return self.graph.subgraph(nodes).copy()

    def get_metrics(self) -> Dict:
        """Compute graph health metrics: node/edge counts, density, average
        (total in+out) degree, and weakly connected component count."""
        if len(self.graph) == 0:
            return {
                'nodes': 0,
                'edges': 0,
                'density': 0.0,
                'avg_degree': 0.0,
                'connected_components': 0
            }

        degrees = dict(self.graph.degree())

        return {
            'nodes': len(self.graph.nodes),
            'edges': len(self.graph.edges),
            'density': float(nx.density(self.graph)),
            'avg_degree': sum(degrees.values()) / len(degrees) if degrees else 0.0,
            'connected_components': nx.number_weakly_connected_components(self.graph)
        }
