"""Common classes/functions for tree index operations."""

import asyncio
import logging
from typing import Dict, List, Optional, Sequence, Tuple

from llama_index.core.async_utils import run_async_tasks
from llama_index.core.callbacks.schema import CBEventType, EventPayload
from llama_index.core.data_structs.data_structs import IndexGraph
from llama_index.core.indices.prompt_helper import PromptHelper
from llama_index.core.indices.utils import get_sorted_node_list, truncate_text
from llama_index.core.llms.llm import LLM
from llama_index.core.prompts import BasePromptTemplate
from llama_index.core.schema import BaseNode, MetadataMode, TextNode
from llama_index.core.settings import Settings
from llama_index.core.storage.docstore import BaseDocumentStore
from llama_index.core.storage.docstore.registry import get_default_docstore
from llama_index.core.utils import get_tqdm_iterable

# Module-level logger; the builder below emits the per-level chunk count at INFO
# and per-summary progress at DEBUG.
logger = logging.getLogger(__name__)


class GPTTreeIndexBuilder:
    """
    GPT tree index builder.

    Helper class to build the tree-structured index,
    or to synthesize an answer.

    Nodes of the current level are grouped into chunks of ``num_children``.
    Each chunk is summarized by the LLM into a new parent ``TextNode``, which
    is inserted into the graph and the docstore. This repeats level by level
    until at most ``num_children`` nodes remain; those become the graph's
    root nodes.

    """

    def __init__(
        self,
        num_children: int,
        summary_prompt: BasePromptTemplate,
        llm: Optional[LLM] = None,
        docstore: Optional[BaseDocumentStore] = None,
        show_progress: bool = False,
        use_async: bool = False,
    ) -> None:
        """
        Initialize with params.

        Args:
            num_children: Maximum number of child nodes summarized into a
                single parent node. Must be at least 2.
            summary_prompt: Prompt used to summarize one chunk of child
                texts; it is formatted with ``context_str``.
            llm: LLM used for summarization. Defaults to ``Settings.llm``.
            docstore: Store that nodes are read from and that new parent
                nodes are written to. Defaults to ``get_default_docstore()``.
            show_progress: Whether to display progress while summarizing.
            use_async: Whether ``build_index_from_nodes`` issues the summary
                requests concurrently through ``run_async_tasks``.

        Raises:
            ValueError: If ``num_children`` is less than 2.

        """
        if num_children < 2:
            # With a fan-out of 1 a level never shrinks, so building would not terminate.
            raise ValueError("Invalid number of children.")
        self.num_children = num_children
        self.summary_prompt = summary_prompt
        self._llm = llm or Settings.llm
        self._prompt_helper = Settings._prompt_helper or PromptHelper.from_llm_metadata(
            self._llm.metadata,
        )
        self._callback_manager = Settings.callback_manager
        self._use_async = use_async
        self._show_progress = show_progress
        self._docstore = docstore or get_default_docstore()

    @property
    def docstore(self) -> BaseDocumentStore:
        """Return docstore."""
        return self._docstore

    def build_from_nodes(
        self,
        nodes: Sequence[BaseNode],
        build_tree: bool = True,
    ) -> IndexGraph:
        """
        Build an index graph from nodes.

        Building the tree reads each node back by id from ``self.docstore``,
        so the nodes are presumably expected to be stored there already.

        Args:
            nodes: Leaf nodes to insert into a fresh ``IndexGraph``.
            build_tree: If True, summarize the leaves into a tree; otherwise
                return the flat graph containing only the inserted nodes.

        Returns:
            IndexGraph: graph object consisting of all_nodes, root_nodes

        """
        index_graph = IndexGraph()
        for node in nodes:
            index_graph.insert(node)

        if build_tree:
            return self.build_index_from_nodes(
                index_graph, index_graph.all_nodes, index_graph.all_nodes, level=0
            )
        else:
            return index_graph

    def _prepare_node_and_text_chunks(
        self, cur_node_ids: Dict[int, str]
    ) -> Tuple[List[int], List[List[BaseNode]], List[str]]:
        """
        Prepare node and text chunks.

        Loads the nodes of the current level from the docstore, orders them
        with ``get_sorted_node_list``, and groups them into consecutive chunks
        of at most ``num_children`` nodes.

        Args:
            cur_node_ids: Mapping of graph index -> node id for the level.

        Returns:
            A tuple of (start offset of each chunk in the sorted node list,
            the nodes of each chunk, the prompt-truncated text of each chunk
            joined with newlines).

        """
        cur_nodes = {
            index: self._docstore.get_node(node_id)
            for index, node_id in cur_node_ids.items()
        }
        cur_node_list = get_sorted_node_list(cur_nodes)
        # Ceiling division: a trailing partial chunk is still a chunk.
        num_chunks = (len(cur_nodes) + self.num_children - 1) // self.num_children
        logger.info("> Building index from nodes: %d chunks", num_chunks)
        indices, cur_nodes_chunks, text_chunks = [], [], []
        for i in range(0, len(cur_node_list), self.num_children):
            cur_nodes_chunk = cur_node_list[i : i + self.num_children]
            # Shrink the child texts so they fit the summary prompt's context.
            truncated_chunks = self._prompt_helper.truncate(
                prompt=self.summary_prompt,
                text_chunks=[
                    node.get_content(metadata_mode=MetadataMode.LLM)
                    for node in cur_nodes_chunk
                ],
                llm=self._llm,
            )
            text_chunk = "\n".join(truncated_chunks)
            indices.append(i)
            cur_nodes_chunks.append(cur_nodes_chunk)
            text_chunks.append(text_chunk)
        return indices, cur_nodes_chunks, text_chunks

    def _construct_parent_nodes(
        self,
        index_graph: IndexGraph,
        indices: List[int],
        cur_nodes_chunks: List[List[BaseNode]],
        summaries: List[str],
    ) -> Dict[int, str]:
        """
        Construct parent nodes.

        Save nodes to docstore.

        Creates one ``TextNode`` per summary, inserts it into ``index_graph``
        with the corresponding chunk as its children, and adds it to the
        docstore (``allow_update=False``).

        Returns:
            Mapping of graph index -> node id for the newly created parents.

        """
        new_node_dict = {}
        for i, cur_nodes_chunk, new_summary in zip(
            indices, cur_nodes_chunks, summaries
        ):
            logger.debug(
                f"> {i}/{len(cur_nodes_chunk)}, "
                f"summary: {truncate_text(new_summary, 50)}"
            )
            new_node = TextNode(text=new_summary)
            index_graph.insert(new_node, children_nodes=cur_nodes_chunk)
            index = index_graph.get_index(new_node)
            new_node_dict[index] = new_node.node_id
            self._docstore.add_documents([new_node], allow_update=False)
        return new_node_dict

    def _finalize_level(
        self,
        index_graph: IndexGraph,
        indices: List[int],
        cur_nodes_chunks: List[List[BaseNode]],
        summaries: List[str],
        all_node_ids: Dict[int, str],
    ) -> Dict[int, str]:
        """Insert one level of parent nodes, record them, and make them the roots."""
        new_node_dict = self._construct_parent_nodes(
            index_graph, indices, cur_nodes_chunks, summaries
        )
        all_node_ids.update(new_node_dict)
        index_graph.root_nodes = new_node_dict
        return new_node_dict

    def build_index_from_nodes(
        self,
        index_graph: IndexGraph,
        cur_node_ids: Dict[int, str],
        all_node_ids: Dict[int, str],
        level: int = 0,
    ) -> IndexGraph:
        """
        Consolidate chunks recursively, in a bottom-up fashion.

        Args:
            index_graph: Graph that receives the new parent nodes.
            cur_node_ids: Mapping of graph index -> node id for this level.
            all_node_ids: Mapping updated in place with every new parent node.
            level: Depth of the current level (reported to callbacks).

        Returns:
            The same ``index_graph`` with ``root_nodes`` set to the top level.

        """
        if len(cur_node_ids) <= self.num_children:
            # Copy so root_nodes never aliases the caller's mapping;
            # build_from_nodes passes index_graph.all_nodes in here.
            index_graph.root_nodes = dict(cur_node_ids)
            return index_graph

        indices, cur_nodes_chunks, text_chunks = self._prepare_node_and_text_chunks(
            cur_node_ids
        )

        with self._callback_manager.event(
            CBEventType.TREE, payload={EventPayload.CHUNKS: text_chunks}
        ) as event:
            if self._use_async:
                tasks = [
                    self._llm.apredict(self.summary_prompt, context_str=text_chunk)
                    for text_chunk in text_chunks
                ]
                # ``apredict`` resolves to the summary string itself, so the
                # results are used as-is (indexing ``[0]`` would keep only the
                # first character of every summary).
                summaries: List[str] = list(
                    run_async_tasks(
                        tasks,
                        show_progress=self._show_progress,
                        progress_bar_desc="Generating summaries",
                    )
                )
            else:
                text_chunks_progress = get_tqdm_iterable(
                    text_chunks,
                    show_progress=self._show_progress,
                    desc="Generating summaries",
                )
                summaries = [
                    self._llm.predict(self.summary_prompt, context_str=text_chunk)
                    for text_chunk in text_chunks_progress
                ]

            event.on_end(payload={"summaries": summaries, "level": level})

        new_node_dict = self._finalize_level(
            index_graph, indices, cur_nodes_chunks, summaries, all_node_ids
        )

        if len(new_node_dict) <= self.num_children:
            return index_graph
        else:
            return self.build_index_from_nodes(
                index_graph, new_node_dict, all_node_ids, level=level + 1
            )

    async def abuild_index_from_nodes(
        self,
        index_graph: IndexGraph,
        cur_node_ids: Dict[int, str],
        all_node_ids: Dict[int, str],
        level: int = 0,
    ) -> IndexGraph:
        """
        Consolidate chunks recursively, in a bottom-up fashion (async).

        Same contract as ``build_index_from_nodes``; the summaries of one
        level are requested concurrently with ``asyncio.gather``.

        Returns:
            The same ``index_graph`` with ``root_nodes`` set to the top level.

        """
        if len(cur_node_ids) <= self.num_children:
            # Copy so root_nodes never aliases the caller's mapping.
            index_graph.root_nodes = dict(cur_node_ids)
            return index_graph

        indices, cur_nodes_chunks, text_chunks = self._prepare_node_and_text_chunks(
            cur_node_ids
        )

        with self._callback_manager.event(
            CBEventType.TREE, payload={EventPayload.CHUNKS: text_chunks}
        ) as event:
            # NOTE(review): this progress bar only tracks creation of the
            # coroutines below, not completion of the LLM calls.
            text_chunks_progress = get_tqdm_iterable(
                text_chunks,
                show_progress=self._show_progress,
                desc="Generating summaries",
            )
            tasks = [
                self._llm.apredict(self.summary_prompt, context_str=text_chunk)
                for text_chunk in text_chunks_progress
            ]
            summaries = list(await asyncio.gather(*tasks))

            event.on_end(payload={"summaries": summaries, "level": level})

        new_node_dict = self._finalize_level(
            index_graph, indices, cur_nodes_chunks, summaries, all_node_ids
        )

        if len(new_node_dict) <= self.num_children:
            return index_graph
        else:
            return await self.abuild_index_from_nodes(
                index_graph, new_node_dict, all_node_ids, level=level + 1
            )
