# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
#   - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch

from llms_from_scratch.ch04 import GPTModel, GPTModelFast
from llms_from_scratch.kv_cache.gpt2 import GPTModel as GPTModelKV
from llms_from_scratch.ch04 import generate_text_simple
from llms_from_scratch.kv_cache.generate import generate_text_simple as generate_text_simple_cached

import pytest
import torch
import tiktoken


# Hyperparameters passed to every model class exercised in this module.
GPT_CONFIG_124M = dict(
    vocab_size=50257,     # size of the token vocabulary
    context_length=1024,  # maximum number of tokens in the context
    emb_dim=768,          # width of the embeddings
    n_heads=12,           # attention heads
    n_layers=12,          # number of layers
    drop_rate=0.1,        # dropout rate
    qkv_bias=False,       # whether query/key/value use a bias term
)


@pytest.mark.parametrize("ModelClass", [GPTModel, GPTModelFast, GPTModelKV])
@pytest.mark.parametrize("generate_fn", [generate_text_simple, generate_text_simple_cached])
def test_gpt_model_variants(ModelClass, generate_fn):
    """Check that a model variant generates the expected token IDs.

    Each model class is paired with each generation function. Pairings
    where exactly one side works with a KV cache are skipped. For the
    remaining pairings, the model is built from ``GPT_CONFIG_124M``, asked
    for 10 new tokens after the prompt "Hello, I am", and the resulting
    token IDs are compared against a hard-coded reference tensor.
    """
    # A model class is treated as KV-cache capable when it exposes a truthy
    # ``reset_kv_cache`` attribute; the cached generator is only paired with
    # such models and the plain generator only with models lacking it.
    model_has_kv_cache = bool(getattr(ModelClass, "reset_kv_cache", False))
    generator_uses_kv_cache = generate_fn is generate_text_simple_cached
    if model_has_kv_cache != generator_uses_kv_cache:
        # Report the pairing as skipped; a bare ``return`` would show up as
        # a passing test even though nothing was verified.
        pytest.skip(
            f"{ModelClass.__name__} (KV cache: {model_has_kv_cache}) is not compatible "
            f"with a generate function that uses a KV cache: {generator_uses_kv_cache}"
        )

    # Fix the RNG seed before building the model; the hard-coded expected
    # tokens below were presumably recorded with this seed.
    torch.manual_seed(123)
    model = ModelClass(GPT_CONFIG_124M)
    model.eval()  # disable dropout

    start_context = "Hello, I am"

    tokenizer = tiktoken.get_encoding("gpt2")
    encoded = tokenizer.encode(start_context)
    # Add a leading dimension of size 1 in front of the token sequence.
    encoded_tensor = torch.tensor(encoded).unsqueeze(0)

    print(f"\n{50*'='}\n{22*' '}IN\n{50*'='}")
    print("\nInput text:", start_context)
    print("Encoded input text:", encoded)
    print("encoded_tensor.shape:", encoded_tensor.shape)

    out = generate_fn(
        model=model,
        idx=encoded_tensor,
        max_new_tokens=10,
        context_size=GPT_CONFIG_124M["context_length"]
    )

    # 4 prompt tokens followed by the 10 generated tokens.
    expect = torch.tensor([
        [15496,   11,   314,   716, 27018, 24086, 47843, 30961, 42348,  7267,
         49706, 43231, 47062, 34657]
    ])
    assert torch.equal(expect, out), "Generated output does not match expected output"
